@@ -1249,6 +1249,11 @@ BasicBlock *ScopStmt::getEntryBlock() const {
1249
1249
1250
1250
unsigned ScopStmt::getNumIterators () const { return NestLoops.size (); }
1251
1251
1252
+ void ScopStmt::setInstructions (ArrayRef<Instruction *> Range) {
1253
+ getParent ()->updateInstStmtMap (Instructions, Range, this );
1254
+ Instructions.assign (Range.begin (), Range.end ());
1255
+ }
1256
+
1252
1257
const char *ScopStmt::getBaseName () const { return BaseName.c_str (); }
1253
1258
1254
1259
Loop *ScopStmt::getLoopForDimension (unsigned Dimension) const {
@@ -1728,8 +1733,10 @@ Scop::Scop(Region &R, ScalarEvolution &ScalarEvolution, LoopInfo &LI,
1728
1733
Scop::~Scop () = default ;
1729
1734
1730
1735
void Scop::removeFromStmtMap (ScopStmt &Stmt) {
1731
- for (Instruction *Inst : Stmt.getInstructions ())
1736
+ for (Instruction *Inst : Stmt.getInstructions ()) {
1737
+ assert (!InstStmtMap.count (Inst) || InstStmtMap.lookup (Inst) == &Stmt);
1732
1738
InstStmtMap.erase (Inst);
1739
+ }
1733
1740
1734
1741
if (Stmt.isRegionStmt ()) {
1735
1742
for (BasicBlock *BB : Stmt.getRegion ()->blocks ()) {
@@ -1738,18 +1745,27 @@ void Scop::removeFromStmtMap(ScopStmt &Stmt) {
1738
1745
// part of the statement's instruction list.
1739
1746
if (BB == Stmt.getEntryBlock ())
1740
1747
continue ;
1741
- for (Instruction &Inst : *BB)
1748
+ for (Instruction &Inst : *BB) {
1749
+ assert (!InstStmtMap.count (&Inst) || InstStmtMap.lookup (&Inst) == &Stmt);
1742
1750
InstStmtMap.erase (&Inst);
1751
+ }
1743
1752
}
1744
1753
} else {
1745
1754
auto StmtMapIt = StmtMap.find (Stmt.getBasicBlock ());
1746
1755
if (StmtMapIt != StmtMap.end ())
1747
1756
StmtMapIt->second .erase (std::remove (StmtMapIt->second .begin (),
1748
1757
StmtMapIt->second .end (), &Stmt),
1749
1758
StmtMapIt->second .end ());
1750
- for (Instruction *Inst : Stmt.getInstructions ())
1759
+ for (Instruction *Inst : Stmt.getInstructions ()) {
1760
+ assert (!InstStmtMap.count (Inst) || InstStmtMap.lookup (Inst) == &Stmt);
1751
1761
InstStmtMap.erase (Inst);
1762
+ }
1752
1763
}
1764
+
1765
+ #ifndef NDEBUG
1766
+ for (auto kv : InstStmtMap)
1767
+ assert (kv.getSecond () != &Stmt);
1768
+ #endif
1753
1769
}
1754
1770
1755
1771
void Scop::removeStmts (std::function<bool (ScopStmt &)> ShouldDelete,
@@ -2471,6 +2487,19 @@ ArrayRef<ScopStmt *> Scop::getStmtListFor(Region *R) const {
2471
2487
return getStmtListFor (R->getEntry ());
2472
2488
}
2473
2489
2490
+ void Scop::updateInstStmtMap (ArrayRef<Instruction *> OldList,
2491
+ ArrayRef<Instruction *> NewList, ScopStmt *Stmt) {
2492
+ for (Instruction *OldInst : OldList) {
2493
+ assert (getStmtFor (OldInst) == Stmt);
2494
+ InstStmtMap.erase (OldInst);
2495
+ }
2496
+
2497
+ for (Instruction *NewInst : NewList) {
2498
+ assert (InstStmtMap.lookup (NewInst) == nullptr );
2499
+ InstStmtMap[NewInst] = Stmt;
2500
+ }
2501
+ }
2502
+
2474
2503
int Scop::getRelativeLoopDepth (const Loop *L) const {
2475
2504
if (!L || !R.contains (L))
2476
2505
return -1 ;
0 commit comments