Skip to content

Commit 80852a4

Browse files
committed
[SCEV] Prove implications of different type via truncation
When we need to prove an implication between expressions of different type widths, the default strategy is to widen everything to the wider type and prove the implication in that type. This does not interact well with AddRecs with negative steps and unsigned predicates: such an AddRec will likely not have a `nuw` flag, and its `zext` to the wider type will not be an AddRec. In contrast, the `trunc` of an AddRec can in some cases easily be proved to be an `AddRec` too. This patch introduces an alternative way of handling implications of different type widths. If we can prove that the wider-type values actually fit in the narrow type, we truncate them and prove the implication in the narrow type. Differential Revision: https://reviews.llvm.org/D89548 Reviewed By: fhahn
1 parent 79a69f5 commit 80852a4

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9699,6 +9699,25 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
96999699
// Balance the types.
97009700
if (getTypeSizeInBits(LHS->getType()) <
97019701
getTypeSizeInBits(FoundLHS->getType())) {
9702+
// For unsigned and equality predicates, try to prove that both found
9703+
// operands fit into narrow unsigned range. If so, try to prove facts in
9704+
// narrow types.
9705+
if (!CmpInst::isSigned(FoundPred)) {
9706+
auto *NarrowType = LHS->getType();
9707+
auto *WideType = FoundLHS->getType();
9708+
auto BitWidth = getTypeSizeInBits(NarrowType);
9709+
const SCEV *MaxValue = getZeroExtendExpr(
9710+
getConstant(APInt::getMaxValue(BitWidth)), WideType);
9711+
if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) &&
9712+
isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) {
9713+
const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
9714+
const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
9715+
if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
9716+
TruncFoundRHS, Context))
9717+
return true;
9718+
}
9719+
}
9720+
97029721
if (CmpInst::isSigned(Pred)) {
97039722
LHS = getSignExtendExpr(LHS, FoundLHS->getType());
97049723
RHS = getSignExtendExpr(RHS, FoundLHS->getType());

llvm/test/Analysis/ScalarEvolution/srem.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ define dso_local void @_Z4loopi(i32 %width) local_unnamed_addr #0 {
2929
; CHECK-NEXT: %add = add nsw i32 %2, %call
3030
; CHECK-NEXT: --> (%2 + %call) U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
3131
; CHECK-NEXT: %inc = add nsw i32 %i.0, 1
32-
; CHECK-NEXT: --> {1,+,1}<nuw><%for.cond> U: [1,0) S: [1,0) Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
32+
; CHECK-NEXT: --> {1,+,1}<nuw><%for.cond> U: full-set S: full-set Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
3333
; CHECK-NEXT: Determining loop execution counts for: @_Z4loopi
3434
; CHECK-NEXT: Loop %for.cond: backedge-taken count is %width
3535
; CHECK-NEXT: Loop %for.cond: max backedge-taken count is -1

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,4 +1316,45 @@ TEST_F(ScalarEvolutionsTest, UnsignedIsImpliedViaOperations) {
13161316
});
13171317
}
13181318

1319+
TEST_F(ScalarEvolutionsTest, ProveImplicationViaNarrowing) {
1320+
LLVMContext C;
1321+
SMDiagnostic Err;
1322+
std::unique_ptr<Module> M = parseAssemblyString(
1323+
"define i32 @foo(i32 %start, i32* %q) { "
1324+
"entry: "
1325+
" %wide.start = zext i32 %start to i64 "
1326+
" br label %loop "
1327+
"loop: "
1328+
" %wide.iv = phi i64 [%wide.start, %entry], [%wide.iv.next, %backedge] "
1329+
" %iv = phi i32 [%start, %entry], [%iv.next, %backedge] "
1330+
" %cond = icmp eq i64 %wide.iv, 0 "
1331+
" br i1 %cond, label %exit, label %backedge "
1332+
"backedge: "
1333+
" %iv.next = add i32 %iv, -1 "
1334+
" %index = zext i32 %iv.next to i64 "
1335+
" %load.addr = getelementptr i32, i32* %q, i64 %index "
1336+
" %stop = load i32, i32* %load.addr "
1337+
" %loop.cond = icmp eq i32 %stop, 0 "
1338+
" %wide.iv.next = add nsw i64 %wide.iv, -1 "
1339+
" br i1 %loop.cond, label %loop, label %failure "
1340+
"exit: "
1341+
" ret i32 0 "
1342+
"failure: "
1343+
" unreachable "
1344+
"} ",
1345+
Err, C);
1346+
1347+
ASSERT_TRUE(M && "Could not parse module?");
1348+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1349+
1350+
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1351+
auto *IV = SE.getSCEV(getInstructionByName(F, "iv"));
1352+
auto *Zero = SE.getZero(IV->getType());
1353+
auto *Backedge = getInstructionByName(F, "iv.next")->getParent();
1354+
ASSERT_TRUE(Backedge);
1355+
EXPECT_TRUE(SE.isBasicBlockEntryGuardedByCond(Backedge, ICmpInst::ICMP_UGT,
1356+
IV, Zero));
1357+
});
1358+
}
1359+
13191360
} // end namespace llvm

0 commit comments

Comments
 (0)