Skip to content

Commit 734ee0e

Browse files
authored
[LVI] Support using block values when handling conditions (#75311)
Currently, LVI will only use conditions like "X < C" to constrain the value of X on the relevant edge. This patch extends it to handle conditions like "X < Y" by querying the known range of Y. This means that getValueFromCondition() and various related APIs can now return nullopt to indicate that they have pushed additional values to the worklist and need to be called again later, once those values have been solved. This behavior is currently controlled by a UseBlockValue option, and is only enabled for actual edge value handling. All other places that derive constraints from conditions keep using the previous logic for now. This change was originally motivated as a fix for the regression reported in a comment on #73662. Unfortunately, it doesn't actually fix that regression, because we run into another issue there (LVI currently handles values used in loops very poorly). This change has some compile-time impact, but it's fairly small, in the 0.05% range.
1 parent 5055eee commit 734ee0e

File tree

2 files changed

+110
-53
lines changed

2 files changed

+110
-53
lines changed

llvm/lib/Analysis/LazyValueInfo.cpp

Lines changed: 107 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,28 @@ class LazyValueInfoImpl {
434434

435435
void solve();
436436

437+
// For the following methods, if UseBlockValue is true, the function may
438+
// push additional values to the worklist and return nullopt. If
439+
// UseBlockValue is false, it will never return nullopt.
440+
441+
std::optional<ValueLatticeElement>
442+
getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, Value *RHS,
443+
const APInt &Offset, Instruction *CxtI,
444+
bool UseBlockValue);
445+
446+
std::optional<ValueLatticeElement>
447+
getValueFromICmpCondition(Value *Val, ICmpInst *ICI, bool isTrueDest,
448+
bool UseBlockValue);
449+
450+
std::optional<ValueLatticeElement>
451+
getValueFromCondition(Value *Val, Value *Cond, bool IsTrueDest,
452+
bool UseBlockValue, unsigned Depth = 0);
453+
454+
std::optional<ValueLatticeElement> getEdgeValueLocal(Value *Val,
455+
BasicBlock *BBFrom,
456+
BasicBlock *BBTo,
457+
bool UseBlockValue);
458+
437459
public:
438460
/// This is the query interface to determine the lattice value for the
439461
/// specified Value* at the context instruction (if specified) or at the
@@ -755,14 +777,10 @@ LazyValueInfoImpl::solveBlockValuePHINode(PHINode *PN, BasicBlock *BB) {
755777
return Result;
756778
}
757779

758-
static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
759-
bool isTrueDest = true,
760-
unsigned Depth = 0);
761-
762780
// If we can determine a constraint on the value given conditions assumed by
763781
// the program, intersect those constraints with BBLV
764782
void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
765-
Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
783+
Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
766784
BBI = BBI ? BBI : dyn_cast<Instruction>(Val);
767785
if (!BBI)
768786
return;
@@ -779,17 +797,21 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
779797
if (I->getParent() != BB || !isValidAssumeForContext(I, BBI))
780798
continue;
781799

782-
BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0)));
800+
BBLV = intersect(BBLV, *getValueFromCondition(Val, I->getArgOperand(0),
801+
/*IsTrueDest*/ true,
802+
/*UseBlockValue*/ false));
783803
}
784804

785805
// If guards are not used in the module, don't spend time looking for them
786806
if (GuardDecl && !GuardDecl->use_empty() &&
787807
BBI->getIterator() != BB->begin()) {
788-
for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()),
789-
BB->rend())) {
808+
for (Instruction &I :
809+
make_range(std::next(BBI->getIterator().getReverse()), BB->rend())) {
790810
Value *Cond = nullptr;
791811
if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond))))
792-
BBLV = intersect(BBLV, getValueFromCondition(Val, Cond));
812+
BBLV = intersect(BBLV,
813+
*getValueFromCondition(Val, Cond, /*IsTrueDest*/ true,
814+
/*UseBlockValue*/ false));
793815
}
794816
}
795817

@@ -886,10 +908,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
886908
// If the value is undef, a different value may be chosen in
887909
// the select condition.
888910
if (isGuaranteedNotToBeUndef(Cond, AC)) {
889-
TrueVal = intersect(TrueVal,
890-
getValueFromCondition(SI->getTrueValue(), Cond, true));
891-
FalseVal = intersect(
892-
FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false));
911+
TrueVal =
912+
intersect(TrueVal, *getValueFromCondition(SI->getTrueValue(), Cond,
913+
/*IsTrueDest*/ true,
914+
/*UseBlockValue*/ false));
915+
FalseVal =
916+
intersect(FalseVal, *getValueFromCondition(SI->getFalseValue(), Cond,
917+
/*IsTrueDest*/ false,
918+
/*UseBlockValue*/ false));
893919
}
894920

895921
ValueLatticeElement Result = TrueVal;
@@ -1068,15 +1094,26 @@ static bool matchICmpOperand(APInt &Offset, Value *LHS, Value *Val,
10681094
}
10691095

10701096
/// Get value range for a "(Val + Offset) Pred RHS" condition.
1071-
static ValueLatticeElement getValueFromSimpleICmpCondition(
1072-
CmpInst::Predicate Pred, Value *RHS, const APInt &Offset) {
1097+
std::optional<ValueLatticeElement>
1098+
LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
1099+
Value *RHS,
1100+
const APInt &Offset,
1101+
Instruction *CxtI,
1102+
bool UseBlockValue) {
10731103
ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
10741104
/*isFullSet=*/true);
1075-
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
1105+
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
10761106
RHSRange = ConstantRange(CI->getValue());
1077-
else if (Instruction *I = dyn_cast<Instruction>(RHS))
1107+
} else if (UseBlockValue) {
1108+
std::optional<ValueLatticeElement> R =
1109+
getBlockValue(RHS, CxtI->getParent(), CxtI);
1110+
if (!R)
1111+
return std::nullopt;
1112+
RHSRange = toConstantRange(*R, RHS->getType());
1113+
} else if (Instruction *I = dyn_cast<Instruction>(RHS)) {
10781114
if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
10791115
RHSRange = getConstantRangeFromMetadata(*Ranges);
1116+
}
10801117

10811118
ConstantRange TrueValues =
10821119
ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
@@ -1103,8 +1140,8 @@ getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS,
11031140
return std::nullopt;
11041141
}
11051142

1106-
static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
1107-
bool isTrueDest) {
1143+
std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
1144+
Value *Val, ICmpInst *ICI, bool isTrueDest, bool UseBlockValue) {
11081145
Value *LHS = ICI->getOperand(0);
11091146
Value *RHS = ICI->getOperand(1);
11101147

@@ -1128,11 +1165,13 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
11281165
unsigned BitWidth = Ty->getScalarSizeInBits();
11291166
APInt Offset(BitWidth, 0);
11301167
if (matchICmpOperand(Offset, LHS, Val, EdgePred))
1131-
return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset);
1168+
return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
1169+
UseBlockValue);
11321170

11331171
CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(EdgePred);
11341172
if (matchICmpOperand(Offset, RHS, Val, SwappedPred))
1135-
return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset);
1173+
return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset, ICI,
1174+
UseBlockValue);
11361175

11371176
const APInt *Mask, *C;
11381177
if (match(LHS, m_And(m_Specific(Val), m_APInt(Mask))) &&
@@ -1212,10 +1251,12 @@ static ValueLatticeElement getValueFromOverflowCondition(
12121251
return ValueLatticeElement::getRange(NWR);
12131252
}
12141253

1215-
static ValueLatticeElement getValueFromCondition(
1216-
Value *Val, Value *Cond, bool IsTrueDest, unsigned Depth) {
1254+
std::optional<ValueLatticeElement>
1255+
LazyValueInfoImpl::getValueFromCondition(Value *Val, Value *Cond,
1256+
bool IsTrueDest, bool UseBlockValue,
1257+
unsigned Depth) {
12171258
if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
1218-
return getValueFromICmpCondition(Val, ICI, IsTrueDest);
1259+
return getValueFromICmpCondition(Val, ICI, IsTrueDest, UseBlockValue);
12191260

12201261
if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
12211262
if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
@@ -1227,7 +1268,7 @@ static ValueLatticeElement getValueFromCondition(
12271268

12281269
Value *N;
12291270
if (match(Cond, m_Not(m_Value(N))))
1230-
return getValueFromCondition(Val, N, !IsTrueDest, Depth);
1271+
return getValueFromCondition(Val, N, !IsTrueDest, UseBlockValue, Depth);
12311272

12321273
Value *L, *R;
12331274
bool IsAnd;
@@ -1238,19 +1279,23 @@ static ValueLatticeElement getValueFromCondition(
12381279
else
12391280
return ValueLatticeElement::getOverdefined();
12401281

1241-
ValueLatticeElement LV = getValueFromCondition(Val, L, IsTrueDest, Depth);
1242-
ValueLatticeElement RV = getValueFromCondition(Val, R, IsTrueDest, Depth);
1282+
std::optional<ValueLatticeElement> LV =
1283+
getValueFromCondition(Val, L, IsTrueDest, UseBlockValue, Depth);
1284+
std::optional<ValueLatticeElement> RV =
1285+
getValueFromCondition(Val, R, IsTrueDest, UseBlockValue, Depth);
1286+
if (!LV || !RV)
1287+
return std::nullopt;
12431288

12441289
// if (L && R) -> intersect L and R
12451290
// if (!(L || R)) -> intersect !L and !R
12461291
// if (L || R) -> union L and R
12471292
// if (!(L && R)) -> union !L and !R
12481293
if (IsTrueDest ^ IsAnd) {
1249-
LV.mergeIn(RV);
1250-
return LV;
1294+
LV->mergeIn(*RV);
1295+
return *LV;
12511296
}
12521297

1253-
return intersect(LV, RV);
1298+
return intersect(*LV, *RV);
12541299
}
12551300

12561301
// Return true if Usr has Op as an operand, otherwise false.
@@ -1302,8 +1347,9 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op,
13021347
}
13031348

13041349
/// Compute the value of Val on the edge BBFrom -> BBTo.
1305-
static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
1306-
BasicBlock *BBTo) {
1350+
std::optional<ValueLatticeElement>
1351+
LazyValueInfoImpl::getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
1352+
BasicBlock *BBTo, bool UseBlockValue) {
13071353
// TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we
13081354
// know that v != 0.
13091355
if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) {
@@ -1324,13 +1370,16 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
13241370

13251371
// If the condition of the branch is an equality comparison, we may be
13261372
// able to infer the value.
1327-
ValueLatticeElement Result = getValueFromCondition(Val, Condition,
1328-
isTrueDest);
1329-
if (!Result.isOverdefined())
1373+
std::optional<ValueLatticeElement> Result =
1374+
getValueFromCondition(Val, Condition, isTrueDest, UseBlockValue);
1375+
if (!Result)
1376+
return std::nullopt;
1377+
1378+
if (!Result->isOverdefined())
13301379
return Result;
13311380

13321381
if (User *Usr = dyn_cast<User>(Val)) {
1333-
assert(Result.isOverdefined() && "Result isn't overdefined");
1382+
assert(Result->isOverdefined() && "Result isn't overdefined");
13341383
// Check with isOperationFoldable() first to avoid linearly iterating
13351384
// over the operands unnecessarily which can be expensive for
13361385
// instructions with many operands.
@@ -1356,8 +1405,8 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
13561405
// br i1 %Condition, label %then, label %else
13571406
for (unsigned i = 0; i < Usr->getNumOperands(); ++i) {
13581407
Value *Op = Usr->getOperand(i);
1359-
ValueLatticeElement OpLatticeVal =
1360-
getValueFromCondition(Op, Condition, isTrueDest);
1408+
ValueLatticeElement OpLatticeVal = *getValueFromCondition(
1409+
Op, Condition, isTrueDest, /*UseBlockValue*/ false);
13611410
if (std::optional<APInt> OpConst =
13621411
OpLatticeVal.asConstantInteger()) {
13631412
Result = constantFoldUser(Usr, Op, *OpConst, DL);
@@ -1367,7 +1416,7 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
13671416
}
13681417
}
13691418
}
1370-
if (!Result.isOverdefined())
1419+
if (!Result->isOverdefined())
13711420
return Result;
13721421
}
13731422
}
@@ -1432,8 +1481,12 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
14321481
if (Constant *VC = dyn_cast<Constant>(Val))
14331482
return ValueLatticeElement::get(VC);
14341483

1435-
ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo);
1436-
if (hasSingleValue(LocalResult))
1484+
std::optional<ValueLatticeElement> LocalResult =
1485+
getEdgeValueLocal(Val, BBFrom, BBTo, /*UseBlockValue*/ true);
1486+
if (!LocalResult)
1487+
return std::nullopt;
1488+
1489+
if (hasSingleValue(*LocalResult))
14371490
// Can't get any more precise here
14381491
return LocalResult;
14391492

@@ -1453,7 +1506,7 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
14531506
// but then the result is not cached.
14541507
intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI);
14551508

1456-
return intersect(LocalResult, InBlock);
1509+
return intersect(*LocalResult, InBlock);
14571510
}
14581511

14591512
ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
@@ -1499,10 +1552,12 @@ getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB,
14991552

15001553
std::optional<ValueLatticeElement> Result =
15011554
getEdgeValue(V, FromBB, ToBB, CxtI);
1502-
if (!Result) {
1555+
while (!Result) {
1556+
// As the worklist only explicitly tracks block values (but not edge values)
1557+
// we may have to call solve() multiple times, as the edge value calculation
1558+
// may request additional block values.
15031559
solve();
15041560
Result = getEdgeValue(V, FromBB, ToBB, CxtI);
1505-
assert(Result && "More work to do after problem solved?");
15061561
}
15071562

15081563
LLVM_DEBUG(dbgs() << " Result = " << *Result << "\n");
@@ -1528,13 +1583,17 @@ ValueLatticeElement LazyValueInfoImpl::getValueAtUse(const Use &U) {
15281583
if (!isGuaranteedNotToBeUndef(SI->getCondition(), AC))
15291584
break;
15301585
if (CurrU->getOperandNo() == 1)
1531-
CondVal = getValueFromCondition(V, SI->getCondition(), true);
1586+
CondVal =
1587+
*getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ true,
1588+
/*UseBlockValue*/ false);
15321589
else if (CurrU->getOperandNo() == 2)
1533-
CondVal = getValueFromCondition(V, SI->getCondition(), false);
1590+
CondVal =
1591+
*getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ false,
1592+
/*UseBlockValue*/ false);
15341593
} else if (auto *PHI = dyn_cast<PHINode>(CurrI)) {
15351594
// TODO: Use non-local query?
1536-
CondVal =
1537-
getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), PHI->getParent());
1595+
CondVal = *getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU),
1596+
PHI->getParent(), /*UseBlockValue*/ false);
15381597
}
15391598
if (CondVal)
15401599
VL = intersect(VL, *CondVal);

llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ define void @test_icmp_from_implied_cond(i32 %a, i32 %b) {
1212
; CHECK-NEXT: [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
1313
; CHECK-NEXT: br i1 [[COND]], label [[L2:%.*]], label [[END]]
1414
; CHECK: l2:
15-
; CHECK-NEXT: [[B_CMP1:%.*]] = icmp ult i32 [[B]], 32
16-
; CHECK-NEXT: call void @use(i1 [[B_CMP1]])
15+
; CHECK-NEXT: call void @use(i1 true)
1716
; CHECK-NEXT: [[B_CMP2:%.*]] = icmp ult i32 [[B]], 31
1817
; CHECK-NEXT: call void @use(i1 [[B_CMP2]])
1918
; CHECK-NEXT: ret void
@@ -47,7 +46,7 @@ define i64 @test_sext_from_implied_cond(i32 %a, i32 %b) {
4746
; CHECK-NEXT: [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
4847
; CHECK-NEXT: br i1 [[COND]], label [[L2:%.*]], label [[END]]
4948
; CHECK: l2:
50-
; CHECK-NEXT: [[SEXT:%.*]] = sext i32 [[B]] to i64
49+
; CHECK-NEXT: [[SEXT:%.*]] = zext nneg i32 [[B]] to i64
5150
; CHECK-NEXT: ret i64 [[SEXT]]
5251
; CHECK: end:
5352
; CHECK-NEXT: ret i64 0
@@ -74,8 +73,7 @@ define void @test_icmp_from_implied_range(i16 %x, i32 %b) {
7473
; CHECK-NEXT: [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
7574
; CHECK-NEXT: br i1 [[COND]], label [[L1:%.*]], label [[END:%.*]]
7675
; CHECK: l1:
77-
; CHECK-NEXT: [[B_CMP1:%.*]] = icmp ult i32 [[B]], 65535
78-
; CHECK-NEXT: call void @use(i1 [[B_CMP1]])
76+
; CHECK-NEXT: call void @use(i1 true)
7977
; CHECK-NEXT: [[B_CMP2:%.*]] = icmp ult i32 [[B]], 65534
8078
; CHECK-NEXT: call void @use(i1 [[B_CMP2]])
8179
; CHECK-NEXT: ret void

0 commit comments

Comments
 (0)