Skip to content

Commit 46abd1f

Browse files
committed
[LoopFlatten] Fix assertion failure in checkOverflow
There is an assertion failure in computeOverflowForUnsignedMul (used in checkOverflow) due to the inner and outer trip counts having different types. This occurs when the IV has been widened, but the loop components are not successfully rediscovered. This is fixed by some refactoring of the code in findLoopComponents which identifies the trip count of the loop.
1 parent c064ba3 commit 46abd1f

File tree

2 files changed

+105
-34
lines changed

2 files changed

+105
-34
lines changed

llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,17 @@ struct FlattenInfo {
9393
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
9494
};
9595

96+
static bool
97+
setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,
98+
SmallPtrSetImpl<Instruction *> &IterationInstructions) {
99+
TripCount = TC;
100+
IterationInstructions.insert(Increment);
101+
LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump());
102+
LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
103+
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
104+
return true;
105+
}
106+
96107
// Finds the induction variable, increment and trip count for a simple loop that
97108
// we can flatten.
98109
static bool findLoopComponents(
@@ -164,49 +175,63 @@ static bool findLoopComponents(
164175
return false;
165176
}
166177
// The trip count is the RHS of the compare. If this doesn't match the trip
167-
// count computed by SCEV then this is either because the trip count variable
168-
// has been widened (then leave the trip count as it is), or because it is a
169-
// constant and another transformation has changed the compare, e.g.
170-
// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1.
171-
TripCount = Compare->getOperand(1);
178+
// count computed by SCEV then this is because the trip count variable
179+
// has been widened so the types don't match, or because it is a constant and
180+
// another transformation has changed the compare (e.g. icmp ult %inc,
181+
// tripcount -> icmp ult %j, tripcount-1), or both.
182+
Value *RHS = Compare->getOperand(1);
172183
const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
173184
if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
174185
LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
175186
return false;
176187
}
177188
const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
178-
if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) {
179-
ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount);
180-
if (!RHS) {
181-
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
182-
return false;
183-
}
184-
// The L->isCanonical check above ensures we only get here if the loop
185-
// increments by 1 on each iteration, so the RHS of the Compare is
186-
// tripcount-1 (i.e equivalent to the backedge taken count).
187-
assert(SE->getSCEV(RHS) == BackedgeTakenCount &&
188-
"Expected RHS of compare to be equal to the backedge taken count");
189-
ConstantInt *One = ConstantInt::get(RHS->getType(), 1);
190-
TripCount = ConstantInt::get(TripCount->getContext(),
191-
RHS->getValue() + One->getValue());
192-
} else if (SE->getSCEV(TripCount) != SCEVTripCount) {
193-
auto *TripCountInst = dyn_cast<Instruction>(TripCount);
194-
if (!TripCountInst) {
195-
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
196-
return false;
189+
const SCEV *SCEVRHS = SE->getSCEV(RHS);
190+
if (SCEVRHS == SCEVTripCount)
191+
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
192+
ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);
193+
if (ConstantRHS) {
194+
const SCEV *BackedgeTCExt = nullptr;
195+
if (IsWidened) {
196+
const SCEV *SCEVTripCountExt;
197+
// Find the extended backedge taken count and extended trip count using
198+
// SCEV. One of these should now match the RHS of the compare.
199+
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
200+
SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt);
201+
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
202+
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
203+
return false;
204+
}
197205
}
198-
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
199-
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
200-
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
201-
return false;
206+
// If the RHS of the compare is equal to the backedge taken count we need
207+
// to add one to get the trip count.
208+
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
209+
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
210+
Value *NewRHS = ConstantInt::get(
211+
ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
212+
return setLoopComponents(NewRHS, TripCount, Increment,
213+
IterationInstructions);
202214
}
215+
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
203216
}
204-
IterationInstructions.insert(Increment);
205-
LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
206-
LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
207-
208-
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
209-
return true;
217+
// If the RHS isn't a constant then check that the reason it doesn't match
218+
// the SCEV trip count is because the RHS is a ZExt or SExt instruction
219+
// (and take the trip count to be the RHS).
220+
if (!IsWidened) {
221+
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
222+
return false;
223+
}
224+
auto *TripCountInst = dyn_cast<Instruction>(RHS);
225+
if (!TripCountInst) {
226+
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
227+
return false;
228+
}
229+
if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
230+
SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
231+
LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
232+
return false;
233+
}
234+
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
210235
}
211236

212237
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {

llvm/test/Transforms/LoopFlatten/widen-iv.ll

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,52 @@ for.cond.cleanup:
525525
ret void
526526
}
527527

528+
; Identify trip count when it is constant and the IV has been widened.
529+
define i32 @constTripCount() {
530+
; CHECK-LABEL: @constTripCount(
531+
; CHECK-NEXT: entry:
532+
; CHECK-NEXT: [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 20, 20
533+
; CHECK-NEXT: br label [[I_LOOP:%.*]]
534+
; CHECK: i.loop:
535+
; CHECK-NEXT: [[INDVAR1:%.*]] = phi i64 [ [[INDVAR_NEXT2:%.*]], [[J_LOOPDONE:%.*]] ], [ 0, [[ENTRY:%.*]] ]
536+
; CHECK-NEXT: br label [[J_LOOP:%.*]]
537+
; CHECK: j.loop:
538+
; CHECK-NEXT: [[INDVAR:%.*]] = phi i64 [ 0, [[I_LOOP]] ]
539+
; CHECK-NEXT: call void @payload()
540+
; CHECK-NEXT: [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1
541+
; CHECK-NEXT: [[J_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT]], 20
542+
; CHECK-NEXT: br label [[J_LOOPDONE]]
543+
; CHECK: j.loopdone:
544+
; CHECK-NEXT: [[INDVAR_NEXT2]] = add i64 [[INDVAR1]], 1
545+
; CHECK-NEXT: [[I_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT2]], [[FLATTEN_TRIPCOUNT]]
546+
; CHECK-NEXT: br i1 [[I_ATEND]], label [[I_LOOPDONE:%.*]], label [[I_LOOP]]
547+
; CHECK: i.loopdone:
548+
; CHECK-NEXT: ret i32 0
549+
;
550+
entry:
551+
br label %i.loop
552+
553+
i.loop:
554+
%i = phi i8 [ 0, %entry ], [ %i.inc, %j.loopdone ]
555+
br label %j.loop
556+
557+
j.loop:
558+
%j = phi i8 [ 0, %i.loop ], [ %j.inc, %j.loop ]
559+
call void @payload()
560+
%j.inc = add i8 %j, 1
561+
%j.atend = icmp eq i8 %j.inc, 20
562+
br i1 %j.atend, label %j.loopdone, label %j.loop
563+
564+
j.loopdone:
565+
%i.inc = add i8 %i, 1
566+
%i.atend = icmp eq i8 %i.inc, 20
567+
br i1 %i.atend, label %i.loopdone, label %i.loop
568+
569+
i.loopdone:
570+
ret i32 0
571+
}
572+
573+
declare void @payload()
528574
declare dso_local i32 @use_32(i32)
529575
declare dso_local i32 @use_16(i16)
530576
declare dso_local i32 @use_64(i64)

0 commit comments

Comments
 (0)