Skip to content

Commit deb6463

Browse files
goldsteinnyuxuanchen1997
authored andcommitted
[ValueTracking] Consistently propagate DemandedElts is ComputeNumSignBits
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250924
1 parent 97b91e0 commit deb6463

File tree

1 file changed

+40
-27
lines changed

1 file changed

+40
-27
lines changed

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3801,7 +3801,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38013801
default: break;
38023802
case Instruction::SExt:
38033803
Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
3804-
return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp;
3804+
return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) +
3805+
Tmp;
38053806

38063807
case Instruction::SDiv: {
38073808
const APInt *Denominator;
@@ -3813,7 +3814,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38133814
break;
38143815

38153816
// Calculate the incoming numerator bits.
3816-
unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3817+
unsigned NumBits =
3818+
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
38173819

38183820
// Add floor(log(C)) bits to the numerator bits.
38193821
return std::min(TyBits, NumBits + Denominator->logBase2());
@@ -3822,7 +3824,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38223824
}
38233825

38243826
case Instruction::SRem: {
3825-
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3827+
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
38263828

38273829
const APInt *Denominator;
38283830
// srem X, C -> we know that the result is within [-C+1,C) when C is a
@@ -3853,7 +3855,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38533855
}
38543856

38553857
case Instruction::AShr: {
3856-
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3858+
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
38573859
// ashr X, C -> adds C sign bits. Vectors too.
38583860
const APInt *ShAmt;
38593861
if (match(U->getOperand(1), m_APInt(ShAmt))) {
@@ -3869,7 +3871,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38693871
const APInt *ShAmt;
38703872
if (match(U->getOperand(1), m_APInt(ShAmt))) {
38713873
// shl destroys sign bits.
3872-
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3874+
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
38733875
if (ShAmt->uge(TyBits) || // Bad shift.
38743876
ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
38753877
Tmp2 = ShAmt->getZExtValue();
@@ -3881,9 +3883,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38813883
case Instruction::Or:
38823884
case Instruction::Xor: // NOT is handled here.
38833885
// Logical binary ops preserve the number of sign bits at the worst.
3884-
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3886+
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
38853887
if (Tmp != 1) {
3886-
Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
3888+
Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
38873889
FirstAnswer = std::min(Tmp, Tmp2);
38883890
// We computed what we know about the sign bits as our first
38893891
// answer. Now proceed to the generic code that uses
@@ -3899,9 +3901,10 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
38993901
if (isSignedMinMaxClamp(U, X, CLow, CHigh))
39003902
return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
39013903

3902-
Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
3903-
if (Tmp == 1) break;
3904-
Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q);
3904+
Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
3905+
if (Tmp == 1)
3906+
break;
3907+
Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Depth + 1, Q);
39053908
return std::min(Tmp, Tmp2);
39063909
}
39073910

@@ -3915,7 +3918,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
39153918
if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
39163919
if (CRHS->isAllOnesValue()) {
39173920
KnownBits Known(TyBits);
3918-
computeKnownBits(U->getOperand(0), Known, Depth + 1, Q);
3921+
computeKnownBits(U->getOperand(0), DemandedElts, Known, Depth + 1, Q);
39193922

39203923
// If the input is known to be 0 or 1, the output is 0/-1, which is
39213924
// all sign bits set.
@@ -3928,19 +3931,21 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
39283931
return Tmp;
39293932
}
39303933

3931-
Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
3932-
if (Tmp2 == 1) break;
3934+
Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
3935+
if (Tmp2 == 1)
3936+
break;
39333937
return std::min(Tmp, Tmp2) - 1;
39343938

39353939
case Instruction::Sub:
3936-
Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
3937-
if (Tmp2 == 1) break;
3940+
Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
3941+
if (Tmp2 == 1)
3942+
break;
39383943

39393944
// Handle NEG.
39403945
if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
39413946
if (CLHS->isNullValue()) {
39423947
KnownBits Known(TyBits);
3943-
computeKnownBits(U->getOperand(1), Known, Depth + 1, Q);
3948+
computeKnownBits(U->getOperand(1), DemandedElts, Known, Depth + 1, Q);
39443949
// If the input is known to be 0 or 1, the output is 0/-1, which is
39453950
// all sign bits set.
39463951
if ((Known.Zero | 1).isAllOnes())
@@ -3957,17 +3962,22 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
39573962

39583963
// Sub can have at most one carry bit. Thus we know that the output
39593964
// is, at worst, one more bit than the inputs.
3960-
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3961-
if (Tmp == 1) break;
3965+
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
3966+
if (Tmp == 1)
3967+
break;
39623968
return std::min(Tmp, Tmp2) - 1;
39633969

39643970
case Instruction::Mul: {
39653971
// The output of the Mul can be at most twice the valid bits in the
39663972
// inputs.
3967-
unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
3968-
if (SignBitsOp0 == 1) break;
3969-
unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
3970-
if (SignBitsOp1 == 1) break;
3973+
unsigned SignBitsOp0 =
3974+
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
3975+
if (SignBitsOp0 == 1)
3976+
break;
3977+
unsigned SignBitsOp1 =
3978+
ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
3979+
if (SignBitsOp1 == 1)
3980+
break;
39713981
unsigned OutValidBits =
39723982
(TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
39733983
return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
@@ -3988,8 +3998,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
39883998
for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
39893999
if (Tmp == 1) return Tmp;
39904000
RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
3991-
Tmp = std::min(
3992-
Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, RecQ));
4001+
Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i),
4002+
DemandedElts, Depth + 1, RecQ));
39934003
}
39944004
return Tmp;
39954005
}
@@ -4050,10 +4060,13 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
40504060
case Instruction::Call: {
40514061
if (const auto *II = dyn_cast<IntrinsicInst>(U)) {
40524062
switch (II->getIntrinsicID()) {
4053-
default: break;
4063+
default:
4064+
break;
40544065
case Intrinsic::abs:
4055-
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
4056-
if (Tmp == 1) break;
4066+
Tmp =
4067+
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
4068+
if (Tmp == 1)
4069+
break;
40574070

40584071
// Absolute value reduces number of sign bits by at most 1.
40594072
return Tmp - 1;

0 commit comments

Comments
 (0)