Skip to content

Commit 6d7cf52

Browse files
authored
[ValueTracking] Improve KnownBits for signed min-max clamping (llvm#120576)
A signed min-max clamp is the sequence of smin and smax intrinsics, which constrain a signed value into the range: smin <= value <= smax. The patch improves the calculation of KnownBits for a value subjected to the signed clamping.
1 parent 3469996 commit 6d7cf52

File tree

2 files changed

+325
-49
lines changed

2 files changed

+325
-49
lines changed

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,63 @@ void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
10651065
Known = CondRes;
10661066
}
10671067

1068+
// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
1069+
// Returns the input and lower/upper bounds.
1070+
static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
1071+
const APInt *&CLow, const APInt *&CHigh) {
1072+
assert(isa<Operator>(Select) &&
1073+
cast<Operator>(Select)->getOpcode() == Instruction::Select &&
1074+
"Input should be a Select!");
1075+
1076+
const Value *LHS = nullptr, *RHS = nullptr;
1077+
SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
1078+
if (SPF != SPF_SMAX && SPF != SPF_SMIN)
1079+
return false;
1080+
1081+
if (!match(RHS, m_APInt(CLow)))
1082+
return false;
1083+
1084+
const Value *LHS2 = nullptr, *RHS2 = nullptr;
1085+
SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
1086+
if (getInverseMinMaxFlavor(SPF) != SPF2)
1087+
return false;
1088+
1089+
if (!match(RHS2, m_APInt(CHigh)))
1090+
return false;
1091+
1092+
if (SPF == SPF_SMIN)
1093+
std::swap(CLow, CHigh);
1094+
1095+
In = LHS2;
1096+
return CLow->sle(*CHigh);
1097+
}
1098+
1099+
static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
1100+
const APInt *&CLow,
1101+
const APInt *&CHigh) {
1102+
assert((II->getIntrinsicID() == Intrinsic::smin ||
1103+
II->getIntrinsicID() == Intrinsic::smax) &&
1104+
"Must be smin/smax");
1105+
1106+
Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
1107+
auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
1108+
if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
1109+
!match(II->getArgOperand(1), m_APInt(CLow)) ||
1110+
!match(InnerII->getArgOperand(1), m_APInt(CHigh)))
1111+
return false;
1112+
1113+
if (II->getIntrinsicID() == Intrinsic::smin)
1114+
std::swap(CLow, CHigh);
1115+
return CLow->sle(*CHigh);
1116+
}
1117+
1118+
static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
1119+
KnownBits &Known) {
1120+
const APInt *CLow, *CHigh;
1121+
if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
1122+
Known = Known.unionWith(ConstantRange(*CLow, *CHigh + 1).toKnownBits());
1123+
}
1124+
10681125
static void computeKnownBitsFromOperator(const Operator *I,
10691126
const APInt &DemandedElts,
10701127
KnownBits &Known, unsigned Depth,
@@ -1804,11 +1861,13 @@ static void computeKnownBitsFromOperator(const Operator *I,
18041861
computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
18051862
computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
18061863
Known = KnownBits::smin(Known, Known2);
1864+
unionWithMinMaxIntrinsicClamp(II, Known);
18071865
break;
18081866
case Intrinsic::smax:
18091867
computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
18101868
computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
18111869
Known = KnownBits::smax(Known, Known2);
1870+
unionWithMinMaxIntrinsicClamp(II, Known);
18121871
break;
18131872
case Intrinsic::ptrmask: {
18141873
computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
@@ -3751,55 +3810,6 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2,
37513810
return false;
37523811
}
37533812

3754-
// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
3755-
// Returns the input and lower/upper bounds.
3756-
static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
3757-
const APInt *&CLow, const APInt *&CHigh) {
3758-
assert(isa<Operator>(Select) &&
3759-
cast<Operator>(Select)->getOpcode() == Instruction::Select &&
3760-
"Input should be a Select!");
3761-
3762-
const Value *LHS = nullptr, *RHS = nullptr;
3763-
SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
3764-
if (SPF != SPF_SMAX && SPF != SPF_SMIN)
3765-
return false;
3766-
3767-
if (!match(RHS, m_APInt(CLow)))
3768-
return false;
3769-
3770-
const Value *LHS2 = nullptr, *RHS2 = nullptr;
3771-
SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
3772-
if (getInverseMinMaxFlavor(SPF) != SPF2)
3773-
return false;
3774-
3775-
if (!match(RHS2, m_APInt(CHigh)))
3776-
return false;
3777-
3778-
if (SPF == SPF_SMIN)
3779-
std::swap(CLow, CHigh);
3780-
3781-
In = LHS2;
3782-
return CLow->sle(*CHigh);
3783-
}
3784-
3785-
static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
3786-
const APInt *&CLow,
3787-
const APInt *&CHigh) {
3788-
assert((II->getIntrinsicID() == Intrinsic::smin ||
3789-
II->getIntrinsicID() == Intrinsic::smax) && "Must be smin/smax");
3790-
3791-
Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
3792-
auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
3793-
if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
3794-
!match(II->getArgOperand(1), m_APInt(CLow)) ||
3795-
!match(InnerII->getArgOperand(1), m_APInt(CHigh)))
3796-
return false;
3797-
3798-
if (II->getIntrinsicID() == Intrinsic::smin)
3799-
std::swap(CLow, CHigh);
3800-
return CLow->sle(*CHigh);
3801-
}
3802-
38033813
/// For vector constants, loop over the elements and find the constant with the
38043814
/// minimum number of sign bits. Return 0 if the value is not a vector constant
38053815
/// or if any element was not analyzed; otherwise, return the count for the

0 commit comments

Comments
 (0)