Skip to content

Commit 8e06bf6

Browse files
committed
[Polly] Ensure consistent Scop::InstStmtMap. NFC.
InstStmtMap became inconsistent with ScopStmt::getInstructions() after the statement's instructions is modified, e.g. by being considered unused by the Simplify pass or being moved by ForwardOpTree. Change ScopStmt::setInstructions() to also update its parent's InstStmtMap. Also add assertions checking the consistency.
1 parent 6983741 commit 8e06bf6

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

polly/include/polly/ScopInfo.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,9 +1539,7 @@ class ScopStmt {
15391539

15401540
/// Set the list of instructions for this statement. It replaces the current
15411541
/// list.
1542-
void setInstructions(ArrayRef<Instruction *> Range) {
1543-
Instructions.assign(Range.begin(), Range.end());
1544-
}
1542+
void setInstructions(ArrayRef<Instruction *> Range);
15451543

15461544
std::vector<Instruction *>::const_iterator insts_begin() const {
15471545
return Instructions.begin();
@@ -1949,7 +1947,7 @@ class Scop {
19491947
void addScopStmt(Region *R, StringRef Name, Loop *SurroundingLoop,
19501948
std::vector<Instruction *> EntryBlockInstructions);
19511949

1952-
/// Removes @p Stmt from the StmtMap.
1950+
/// Removes @p Stmt from the StmtMap and InstStmtMap.
19531951
void removeFromStmtMap(ScopStmt &Stmt);
19541952

19551953
/// Removes all statements where the entry block of the statement does not
@@ -2362,6 +2360,12 @@ class Scop {
23622360
return InstStmtMap.lookup(Inst);
23632361
}
23642362

2363+
/// Update the content of InstStmtMap for @p Stmt. @p OldList contains the
2364+
/// previous instructions in @p Stmt and is updated to contain the
2365+
/// instructions in @p NewList.
2366+
void updateInstStmtMap(ArrayRef<Instruction *> OldList,
2367+
ArrayRef<Instruction *> NewList, ScopStmt *Stmt);
2368+
23652369
/// Return the number of statements in the SCoP.
23662370
size_t getSize() const { return Stmts.size(); }
23672371

polly/lib/Analysis/ScopInfo.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,11 @@ BasicBlock *ScopStmt::getEntryBlock() const {
12491249

12501250
unsigned ScopStmt::getNumIterators() const { return NestLoops.size(); }
12511251

1252+
void ScopStmt::setInstructions(ArrayRef<Instruction *> Range) {
1253+
getParent()->updateInstStmtMap(Instructions, Range, this);
1254+
Instructions.assign(Range.begin(), Range.end());
1255+
}
1256+
12521257
const char *ScopStmt::getBaseName() const { return BaseName.c_str(); }
12531258

12541259
Loop *ScopStmt::getLoopForDimension(unsigned Dimension) const {
@@ -1728,8 +1733,10 @@ Scop::Scop(Region &R, ScalarEvolution &ScalarEvolution, LoopInfo &LI,
17281733
Scop::~Scop() = default;
17291734

17301735
void Scop::removeFromStmtMap(ScopStmt &Stmt) {
1731-
for (Instruction *Inst : Stmt.getInstructions())
1736+
for (Instruction *Inst : Stmt.getInstructions()) {
1737+
assert(!InstStmtMap.count(Inst) || InstStmtMap.lookup(Inst) == &Stmt);
17321738
InstStmtMap.erase(Inst);
1739+
}
17331740

17341741
if (Stmt.isRegionStmt()) {
17351742
for (BasicBlock *BB : Stmt.getRegion()->blocks()) {
@@ -1738,18 +1745,27 @@ void Scop::removeFromStmtMap(ScopStmt &Stmt) {
17381745
// part of the statement's instruction list.
17391746
if (BB == Stmt.getEntryBlock())
17401747
continue;
1741-
for (Instruction &Inst : *BB)
1748+
for (Instruction &Inst : *BB) {
1749+
assert(!InstStmtMap.count(&Inst) || InstStmtMap.lookup(&Inst) == &Stmt);
17421750
InstStmtMap.erase(&Inst);
1751+
}
17431752
}
17441753
} else {
17451754
auto StmtMapIt = StmtMap.find(Stmt.getBasicBlock());
17461755
if (StmtMapIt != StmtMap.end())
17471756
StmtMapIt->second.erase(std::remove(StmtMapIt->second.begin(),
17481757
StmtMapIt->second.end(), &Stmt),
17491758
StmtMapIt->second.end());
1750-
for (Instruction *Inst : Stmt.getInstructions())
1759+
for (Instruction *Inst : Stmt.getInstructions()) {
1760+
assert(!InstStmtMap.count(Inst) || InstStmtMap.lookup(Inst) == &Stmt);
17511761
InstStmtMap.erase(Inst);
1762+
}
17521763
}
1764+
1765+
#ifndef NDEBUG
1766+
for (auto kv : InstStmtMap)
1767+
assert(kv.getSecond() != &Stmt);
1768+
#endif
17531769
}
17541770

17551771
void Scop::removeStmts(std::function<bool(ScopStmt &)> ShouldDelete,
@@ -2471,6 +2487,19 @@ ArrayRef<ScopStmt *> Scop::getStmtListFor(Region *R) const {
24712487
return getStmtListFor(R->getEntry());
24722488
}
24732489

2490+
void Scop::updateInstStmtMap(ArrayRef<Instruction *> OldList,
2491+
ArrayRef<Instruction *> NewList, ScopStmt *Stmt) {
2492+
for (Instruction *OldInst : OldList) {
2493+
assert(getStmtFor(OldInst) == Stmt);
2494+
InstStmtMap.erase(OldInst);
2495+
}
2496+
2497+
for (Instruction *NewInst : NewList) {
2498+
assert(InstStmtMap.lookup(NewInst) == nullptr);
2499+
InstStmtMap[NewInst] = Stmt;
2500+
}
2501+
}
2502+
24742503
int Scop::getRelativeLoopDepth(const Loop *L) const {
24752504
if (!L || !R.contains(L))
24762505
return -1;

0 commit comments

Comments
 (0)