Skip to content

Commit bbe3c19

Browse files
committed
[DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis
The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as Thumb2 MVE). This patch proposes we generalize these helpers to fall back to KnownBits analysis if a scalar/build_vector constant isn't available. Most restrictions are the same - the helpers fail if any shift amount is out of bounds, getValidShiftAmount must still resolve to a single uniform constant, etc. However, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bound values that don't occur in the actual data, as they are based on the common KnownBits of every vector element. This addresses feedback on #92096
1 parent 7c265e9 commit bbe3c19

File tree

5 files changed

+155
-135
lines changed

5 files changed

+155
-135
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,36 +2159,38 @@ class SelectionDAG {
21592159
/// splatted value it will return SDValue().
21602160
SDValue getSplatValue(SDValue V, bool LegalTypes = false);
21612161

2162-
/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
2162+
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
21632163
/// that is less than the element bit-width of the shift node, return it.
2164-
const APInt *getValidShiftAmountConstant(SDValue V,
2165-
const APInt &DemandedElts) const;
2164+
std::optional<uint64_t> getValidShiftAmount(SDValue V,
2165+
const APInt &DemandedElts,
2166+
unsigned Depth = 0) const;
21662167

2167-
/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
2168+
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
21682169
/// that is less than the element bit-width of the shift node, return it.
2169-
const APInt *getValidShiftAmountConstant(SDValue V) const;
2170-
2171-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2172-
/// than the element bit-width of the shift node, return the minimum value.
2173-
const APInt *
2174-
getValidMinimumShiftAmountConstant(SDValue V,
2175-
const APInt &DemandedElts) const;
2176-
2177-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2178-
/// than the element bit-width of the shift node, return the minimum value.
2179-
const APInt *
2180-
getValidMinimumShiftAmountConstant(SDValue V) const;
2181-
2182-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2183-
/// than the element bit-width of the shift node, return the maximum value.
2184-
const APInt *
2185-
getValidMaximumShiftAmountConstant(SDValue V,
2186-
const APInt &DemandedElts) const;
2187-
2188-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2189-
/// than the element bit-width of the shift node, return the maximum value.
2190-
const APInt *
2191-
getValidMaximumShiftAmountConstant(SDValue V) const;
2170+
std::optional<uint64_t> getValidShiftAmount(SDValue V,
2171+
unsigned Depth = 0) const;
2172+
2173+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2174+
/// element bit-width of the shift node, return the minimum possible value.
2175+
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
2176+
const APInt &DemandedElts,
2177+
unsigned Depth = 0) const;
2178+
2179+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2180+
/// element bit-width of the shift node, return the minimum possible value.
2181+
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
2182+
unsigned Depth = 0) const;
2183+
2184+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2185+
/// element bit-width of the shift node, return the maximum possible value.
2186+
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
2187+
const APInt &DemandedElts,
2188+
unsigned Depth = 0) const;
2189+
2190+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2191+
/// element bit-width of the shift node, return the maximum possible value.
2192+
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
2193+
unsigned Depth = 0) const;
21922194

21932195
/// Match a binop + shuffle pyramid that represents a horizontal reduction
21942196
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node \p

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 91 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
30093009
return SDValue();
30103010
}
30113011

3012-
const APInt *
3013-
SelectionDAG::getValidShiftAmountConstant(SDValue V,
3014-
const APInt &DemandedElts) const {
3012+
std::optional<uint64_t>
3013+
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
3014+
unsigned Depth) const {
30153015
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30163016
V.getOpcode() == ISD::SRA) &&
30173017
"Unknown shift node");
@@ -3020,91 +3020,113 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
30203020
// Shifting more than the bitwidth is not valid.
30213021
const APInt &ShAmt = SA->getAPIntValue();
30223022
if (ShAmt.ult(BitWidth))
3023-
return &ShAmt;
3023+
return ShAmt.getZExtValue();
3024+
} else {
3025+
// Use computeKnownBits to find a hidden constant (usually type legalized).
3026+
// e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3027+
KnownBits KnownAmt =
3028+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3029+
if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
3030+
return KnownAmt.getConstant().getZExtValue();
30243031
}
3025-
return nullptr;
3032+
return std::nullopt;
30263033
}
30273034

3028-
const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
3035+
std::optional<uint64_t>
3036+
SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
30293037
EVT VT = V.getValueType();
30303038
APInt DemandedElts = VT.isFixedLengthVector()
30313039
? APInt::getAllOnes(VT.getVectorNumElements())
30323040
: APInt(1, 1);
3033-
return getValidShiftAmountConstant(V, DemandedElts);
3041+
return getValidShiftAmount(V, DemandedElts, Depth);
30343042
}
30353043

3036-
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
3037-
SDValue V, const APInt &DemandedElts) const {
3044+
std::optional<uint64_t>
3045+
SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
3046+
unsigned Depth) const {
30383047
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30393048
V.getOpcode() == ISD::SRA) &&
30403049
"Unknown shift node");
3041-
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
3042-
return ValidAmt;
30433050
unsigned BitWidth = V.getScalarValueSizeInBits();
3044-
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
3045-
if (!BV)
3046-
return nullptr;
3047-
const APInt *MinShAmt = nullptr;
3048-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3049-
if (!DemandedElts[i])
3050-
continue;
3051-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3052-
if (!SA)
3053-
return nullptr;
3054-
// Shifting more than the bitwidth is not valid.
3055-
const APInt &ShAmt = SA->getAPIntValue();
3056-
if (ShAmt.uge(BitWidth))
3057-
return nullptr;
3058-
if (MinShAmt && MinShAmt->ule(ShAmt))
3059-
continue;
3060-
MinShAmt = &ShAmt;
3051+
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3052+
const APInt *MinShAmt = nullptr;
3053+
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3054+
if (!DemandedElts[i])
3055+
continue;
3056+
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3057+
if (!SA) {
3058+
MinShAmt = nullptr;
3059+
break;
3060+
}
3061+
// Shifting more than the bitwidth is not valid.
3062+
const APInt &ShAmt = SA->getAPIntValue();
3063+
if (ShAmt.uge(BitWidth))
3064+
return std::nullopt;
3065+
if (MinShAmt && MinShAmt->ule(ShAmt))
3066+
continue;
3067+
MinShAmt = &ShAmt;
3068+
}
3069+
if (MinShAmt)
3070+
return MinShAmt->getZExtValue();
30613071
}
3062-
return MinShAmt;
3072+
KnownBits KnownAmt =
3073+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3074+
if (KnownAmt.getMaxValue().ult(BitWidth))
3075+
return KnownAmt.getMinValue().getZExtValue();
3076+
return std::nullopt;
30633077
}
30643078

3065-
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
3079+
std::optional<uint64_t>
3080+
SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
30663081
EVT VT = V.getValueType();
30673082
APInt DemandedElts = VT.isFixedLengthVector()
30683083
? APInt::getAllOnes(VT.getVectorNumElements())
30693084
: APInt(1, 1);
3070-
return getValidMinimumShiftAmountConstant(V, DemandedElts);
3085+
return getValidMinimumShiftAmount(V, DemandedElts, Depth);
30713086
}
30723087

3073-
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
3074-
SDValue V, const APInt &DemandedElts) const {
3088+
std::optional<uint64_t>
3089+
SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
3090+
unsigned Depth) const {
30753091
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30763092
V.getOpcode() == ISD::SRA) &&
30773093
"Unknown shift node");
3078-
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
3079-
return ValidAmt;
30803094
unsigned BitWidth = V.getScalarValueSizeInBits();
3081-
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
3082-
if (!BV)
3083-
return nullptr;
3084-
const APInt *MaxShAmt = nullptr;
3085-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3086-
if (!DemandedElts[i])
3087-
continue;
3088-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3089-
if (!SA)
3090-
return nullptr;
3091-
// Shifting more than the bitwidth is not valid.
3092-
const APInt &ShAmt = SA->getAPIntValue();
3093-
if (ShAmt.uge(BitWidth))
3094-
return nullptr;
3095-
if (MaxShAmt && MaxShAmt->uge(ShAmt))
3096-
continue;
3097-
MaxShAmt = &ShAmt;
3095+
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3096+
const APInt *MaxShAmt = nullptr;
3097+
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3098+
if (!DemandedElts[i])
3099+
continue;
3100+
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3101+
if (!SA) {
3102+
MaxShAmt = nullptr;
3103+
break;
3104+
}
3105+
// Shifting more than the bitwidth is not valid.
3106+
const APInt &ShAmt = SA->getAPIntValue();
3107+
if (ShAmt.uge(BitWidth))
3108+
return std::nullopt;
3109+
if (MaxShAmt && MaxShAmt->uge(ShAmt))
3110+
continue;
3111+
MaxShAmt = &ShAmt;
3112+
}
3113+
if (MaxShAmt)
3114+
return MaxShAmt->getZExtValue();
30983115
}
3099-
return MaxShAmt;
3116+
KnownBits KnownAmt =
3117+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3118+
if (KnownAmt.getMaxValue().ult(BitWidth))
3119+
return KnownAmt.getMaxValue().getZExtValue();
3120+
return std::nullopt;
31003121
}
31013122

3102-
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
3123+
std::optional<uint64_t>
3124+
SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
31033125
EVT VT = V.getValueType();
31043126
APInt DemandedElts = VT.isFixedLengthVector()
31053127
? APInt::getAllOnes(VT.getVectorNumElements())
31063128
: APInt(1, 1);
3107-
return getValidMaximumShiftAmountConstant(V, DemandedElts);
3129+
return getValidMaximumShiftAmount(V, DemandedElts, Depth);
31083130
}
31093131

31103132
/// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3591,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35693591
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
35703592

35713593
// Minimum shift low bits are known zero.
3572-
if (const APInt *ShMinAmt =
3573-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
3574-
Known.Zero.setLowBits(ShMinAmt->getZExtValue());
3594+
if (std::optional<uint64_t> ShMinAmt =
3595+
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
3596+
Known.Zero.setLowBits(*ShMinAmt);
35753597
break;
35763598
}
35773599
case ISD::SRL:
@@ -3581,9 +3603,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35813603
Op->getFlags().hasExact());
35823604

35833605
// Minimum shift high bits are known zero.
3584-
if (const APInt *ShMinAmt =
3585-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
3586-
Known.Zero.setHighBits(ShMinAmt->getZExtValue());
3606+
if (std::optional<uint64_t> ShMinAmt =
3607+
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
3608+
Known.Zero.setHighBits(*ShMinAmt);
35873609
break;
35883610
case ISD::SRA:
35893611
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4609,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
45874609
case ISD::SRA:
45884610
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
45894611
// SRA X, C -> adds C sign bits.
4590-
if (const APInt *ShAmt =
4591-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
4592-
Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
4612+
if (std::optional<uint64_t> ShAmt =
4613+
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
4614+
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
45934615
return Tmp;
45944616
case ISD::SHL:
4595-
if (const APInt *ShAmt =
4596-
getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
4617+
if (std::optional<uint64_t> ShAmt =
4618+
getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
45974619
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
45984620
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4599-
if (ShAmt->ult(Tmp))
4600-
return Tmp - ShAmt->getZExtValue();
4621+
if (*ShAmt < Tmp)
4622+
return Tmp - *ShAmt;
46014623
}
46024624
break;
46034625
case ISD::AND:
@@ -5270,7 +5292,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
52705292
case ISD::SRL:
52715293
case ISD::SRA:
52725294
// If the max shift amount isn't in range, then the shift can create poison.
5273-
return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
5295+
return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
52745296

52755297
case ISD::SCALAR_TO_VECTOR:
52765298
// Check if we demand any upper (undef) elements.

0 commit comments

Comments
 (0)