@@ -93,6 +93,17 @@ struct FlattenInfo {
93
93
FlattenInfo (Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
94
94
};
95
95
96
// Shared success path for findLoopComponents: records the loop components
// that were identified and reports success.
//
// \param TC                     The value chosen as the trip count. It is only
//                               read here, even though it is passed as a
//                               reference to a pointer.
// \param TripCount              Out-parameter; set to TC.
// \param Increment              The increment instruction of the loop; added
//                               to IterationInstructions and dumped in the
//                               debug trace.
// \param IterationInstructions  Set that receives Increment.
// \returns Always true, so callers can write
//          `return setLoopComponents(...)` on each successful exit.
+ static bool
97
+ setLoopComponents (Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
98
+ SmallPtrSetImpl<Instruction *> &IterationInstructions) {
99
// Publish the selected trip count through the caller's out-parameter.
+ TripCount = TC;
100
// Remember the increment as one of the loop's iteration instructions.
+ IterationInstructions.insert (Increment);
101
// Trace what was found via LLVM_DEBUG; TripCount is dumped after it has been
// assigned above, so it shows the value handed back to the caller.
+ LLVM_DEBUG (dbgs () << " Found Increment: " ; Increment->dump ());
102
+ LLVM_DEBUG (dbgs () << " Found trip count: " ; TripCount->dump ());
103
+ LLVM_DEBUG (dbgs () << " Successfully found all loop components\n " );
104
// There is no failure path in this helper.
+ return true ;
105
+ }
106
+
96
107
// Finds the induction variable, increment and trip count for a simple loop that
97
108
// we can flatten.
98
109
static bool findLoopComponents (
@@ -164,49 +175,63 @@ static bool findLoopComponents(
164
175
return false ;
165
176
}
166
177
// The trip count is the RHS of the compare. If this doesn't match the trip
167
- // count computed by SCEV then this is either because the trip count variable
168
- // has been widened (then leave the trip count as it is) , or because it is a
169
- // constant and another transformation has changed the compare, e.g.
170
- // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1.
171
- TripCount = Compare->getOperand (1 );
178
+ // count computed by SCEV then this is because the trip count variable
179
+ // has been widened so the types don't match , or because it is a constant and
180
+ // another transformation has changed the compare ( e.g. icmp ult %inc,
181
+ // tripcount -> icmp ult %j, tripcount-1), or both .
182
+ Value *RHS = Compare->getOperand (1 );
172
183
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount (L);
173
184
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
174
185
LLVM_DEBUG (dbgs () << " Backedge-taken count is not predictable\n " );
175
186
return false ;
176
187
}
177
188
const SCEV *SCEVTripCount = SE->getTripCountFromExitCount (BackedgeTakenCount);
178
- if (SE->getSCEV (TripCount) != SCEVTripCount && !IsWidened) {
179
- ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount);
180
- if (!RHS) {
181
- LLVM_DEBUG (dbgs () << " Could not find valid trip count\n " );
182
- return false ;
183
- }
184
- // The L->isCanonical check above ensures we only get here if the loop
185
- // increments by 1 on each iteration, so the RHS of the Compare is
186
- // tripcount-1 (i.e equivalent to the backedge taken count).
187
- assert (SE->getSCEV (RHS) == BackedgeTakenCount &&
188
- " Expected RHS of compare to be equal to the backedge taken count" );
189
- ConstantInt *One = ConstantInt::get (RHS->getType (), 1 );
190
- TripCount = ConstantInt::get (TripCount->getContext (),
191
- RHS->getValue () + One->getValue ());
192
- } else if (SE->getSCEV (TripCount) != SCEVTripCount) {
193
- auto *TripCountInst = dyn_cast<Instruction>(TripCount);
194
- if (!TripCountInst) {
195
- LLVM_DEBUG (dbgs () << " Could not find valid extended trip count\n " );
196
- return false ;
189
+ const SCEV *SCEVRHS = SE->getSCEV (RHS);
190
+ if (SCEVRHS == SCEVTripCount)
191
+ return setLoopComponents (RHS, TripCount, Increment, IterationInstructions);
192
+ ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
193
+ if (ConstantRHS) {
194
+ const SCEV *BackedgeTCExt = nullptr ;
195
+ if (IsWidened) {
196
+ const SCEV *SCEVTripCountExt;
197
+ // Find the extended backedge taken count and extended trip count using
198
+ // SCEV. One of these should now match the RHS of the compare.
199
+ BackedgeTCExt = SE->getZeroExtendExpr (BackedgeTakenCount, RHS->getType ());
200
+ SCEVTripCountExt = SE->getTripCountFromExitCount (BackedgeTCExt);
201
+ if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
202
+ LLVM_DEBUG (dbgs () << " Could not find valid trip count\n " );
203
+ return false ;
204
+ }
197
205
}
198
- if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
199
- SE->getSCEV (TripCountInst->getOperand (0 )) != SCEVTripCount) {
200
- LLVM_DEBUG (dbgs () << " Could not find valid extended trip count\n " );
201
- return false ;
206
+ // If the RHS of the compare is equal to the backedge taken count we need
207
+ // to add one to get the trip count.
208
+ if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
209
+ ConstantInt *One = ConstantInt::get (ConstantRHS->getType (), 1 );
210
+ Value *NewRHS = ConstantInt::get (
211
+ ConstantRHS->getContext (), ConstantRHS->getValue () + One->getValue ());
212
+ return setLoopComponents (NewRHS, TripCount, Increment,
213
+ IterationInstructions);
202
214
}
215
+ return setLoopComponents (RHS, TripCount, Increment, IterationInstructions);
203
216
}
204
- IterationInstructions.insert (Increment);
205
- LLVM_DEBUG (dbgs () << " Found increment: " ; Increment->dump ());
206
- LLVM_DEBUG (dbgs () << " Found trip count: " ; TripCount->dump ());
207
-
208
- LLVM_DEBUG (dbgs () << " Successfully found all loop components\n " );
209
- return true ;
217
+ // If the RHS isn't a constant then check that the reason it doesn't match
218
+ // the SCEV trip count is because the RHS is a ZExt or SExt instruction
219
+ // (and take the trip count to be the RHS).
220
+ if (!IsWidened) {
221
+ LLVM_DEBUG (dbgs () << " Could not find valid trip count\n " );
222
+ return false ;
223
+ }
224
+ auto *TripCountInst = dyn_cast<Instruction>(RHS);
225
+ if (!TripCountInst) {
226
+ LLVM_DEBUG (dbgs () << " Could not find valid trip count\n " );
227
+ return false ;
228
+ }
229
+ if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
230
+ SE->getSCEV (TripCountInst->getOperand (0 )) != SCEVTripCount) {
231
+ LLVM_DEBUG (dbgs () << " Could not find valid extended trip count\n " );
232
+ return false ;
233
+ }
234
+ return setLoopComponents (RHS, TripCount, Increment, IterationInstructions);
210
235
}
211
236
212
237
static bool checkPHIs (FlattenInfo &FI, const TargetTransformInfo *TTI) {
0 commit comments