Skip to content

Commit 9ced5e9

Browse files
committed
[DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis
The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as Thumb2 MVE). This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available. Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc. However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element. This addresses feedback on #92096
1 parent ebc6c28 commit 9ced5e9

File tree

5 files changed

+153
-135
lines changed

5 files changed

+153
-135
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,36 +2159,38 @@ class SelectionDAG {
21592159
/// splatted value it will return SDValue().
21602160
SDValue getSplatValue(SDValue V, bool LegalTypes = false);
21612161

2162-
/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
2162+
/// If a SHL/SRA/SRL node \p V has an uniform shift amount
21632163
/// that is less than the element bit-width of the shift node, return it.
2164-
const APInt *getValidShiftAmountConstant(SDValue V,
2165-
const APInt &DemandedElts) const;
2164+
std::optional<uint64_t> getValidShiftAmount(SDValue V,
2165+
const APInt &DemandedElts,
2166+
unsigned Depth = 0) const;
21662167

2167-
/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
2168+
/// If a SHL/SRA/SRL node \p V has an uniform shift amount
21682169
/// that is less than the element bit-width of the shift node, return it.
2169-
const APInt *getValidShiftAmountConstant(SDValue V) const;
2170-
2171-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2172-
/// than the element bit-width of the shift node, return the minimum value.
2173-
const APInt *
2174-
getValidMinimumShiftAmountConstant(SDValue V,
2175-
const APInt &DemandedElts) const;
2176-
2177-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2178-
/// than the element bit-width of the shift node, return the minimum value.
2179-
const APInt *
2180-
getValidMinimumShiftAmountConstant(SDValue V) const;
2181-
2182-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2183-
/// than the element bit-width of the shift node, return the maximum value.
2184-
const APInt *
2185-
getValidMaximumShiftAmountConstant(SDValue V,
2186-
const APInt &DemandedElts) const;
2187-
2188-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2189-
/// than the element bit-width of the shift node, return the maximum value.
2190-
const APInt *
2191-
getValidMaximumShiftAmountConstant(SDValue V) const;
2170+
std::optional<uint64_t> getValidShiftAmount(SDValue V,
2171+
unsigned Depth = 0) const;
2172+
2173+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2174+
/// element bit-width of the shift node, return the minimum possible value.
2175+
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
2176+
const APInt &DemandedElts,
2177+
unsigned Depth = 0) const;
2178+
2179+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2180+
/// element bit-width of the shift node, return the minimum possible value.
2181+
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
2182+
unsigned Depth = 0) const;
2183+
2184+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2185+
/// element bit-width of the shift node, return the maximum possible value.
2186+
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
2187+
const APInt &DemandedElts,
2188+
unsigned Depth = 0) const;
2189+
2190+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2191+
/// element bit-width of the shift node, return the maximum possible value.
2192+
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
2193+
unsigned Depth = 0) const;
21922194

21932195
/// Match a binop + shuffle pyramid that represents a horizontal reduction
21942196
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 89 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
30093009
return SDValue();
30103010
}
30113011

3012-
const APInt *
3013-
SelectionDAG::getValidShiftAmountConstant(SDValue V,
3014-
const APInt &DemandedElts) const {
3012+
std::optional<uint64_t>
3013+
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
3014+
unsigned Depth) const {
30153015
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30163016
V.getOpcode() == ISD::SRA) &&
30173017
"Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
30203020
// Shifting more than the bitwidth is not valid.
30213021
const APInt &ShAmt = SA->getAPIntValue();
30223022
if (ShAmt.ult(BitWidth))
3023-
return &ShAmt;
3023+
return ShAmt.getZExtValue();
3024+
} else {
3025+
KnownBits KnownAmt =
3026+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3027+
if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
3028+
return KnownAmt.getConstant().getZExtValue();
30243029
}
3025-
return nullptr;
3030+
return std::nullopt;
30263031
}
30273032

3028-
const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
3033+
std::optional<uint64_t>
3034+
SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
30293035
EVT VT = V.getValueType();
30303036
APInt DemandedElts = VT.isFixedLengthVector()
30313037
? APInt::getAllOnes(VT.getVectorNumElements())
30323038
: APInt(1, 1);
3033-
return getValidShiftAmountConstant(V, DemandedElts);
3039+
return getValidShiftAmount(V, DemandedElts, Depth);
30343040
}
30353041

3036-
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
3037-
SDValue V, const APInt &DemandedElts) const {
3042+
std::optional<uint64_t>
3043+
SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
3044+
unsigned Depth) const {
30383045
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30393046
V.getOpcode() == ISD::SRA) &&
30403047
"Unknown shift node");
3041-
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
3042-
return ValidAmt;
30433048
unsigned BitWidth = V.getScalarValueSizeInBits();
3044-
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
3045-
if (!BV)
3046-
return nullptr;
3047-
const APInt *MinShAmt = nullptr;
3048-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3049-
if (!DemandedElts[i])
3050-
continue;
3051-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3052-
if (!SA)
3053-
return nullptr;
3054-
// Shifting more than the bitwidth is not valid.
3055-
const APInt &ShAmt = SA->getAPIntValue();
3056-
if (ShAmt.uge(BitWidth))
3057-
return nullptr;
3058-
if (MinShAmt && MinShAmt->ule(ShAmt))
3059-
continue;
3060-
MinShAmt = &ShAmt;
3049+
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3050+
const APInt *MinShAmt = nullptr;
3051+
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3052+
if (!DemandedElts[i])
3053+
continue;
3054+
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3055+
if (!SA) {
3056+
MinShAmt = nullptr;
3057+
break;
3058+
}
3059+
// Shifting more than the bitwidth is not valid.
3060+
const APInt &ShAmt = SA->getAPIntValue();
3061+
if (ShAmt.uge(BitWidth))
3062+
return std::nullopt;
3063+
if (MinShAmt && MinShAmt->ule(ShAmt))
3064+
continue;
3065+
MinShAmt = &ShAmt;
3066+
}
3067+
if (MinShAmt)
3068+
return MinShAmt->getZExtValue();
30613069
}
3062-
return MinShAmt;
3070+
KnownBits KnownAmt =
3071+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3072+
if (KnownAmt.getMaxValue().ult(BitWidth))
3073+
return KnownAmt.getMinValue().getZExtValue();
3074+
return std::nullopt;
30633075
}
30643076

3065-
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
3077+
std::optional<uint64_t>
3078+
SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
30663079
EVT VT = V.getValueType();
30673080
APInt DemandedElts = VT.isFixedLengthVector()
30683081
? APInt::getAllOnes(VT.getVectorNumElements())
30693082
: APInt(1, 1);
3070-
return getValidMinimumShiftAmountConstant(V, DemandedElts);
3083+
return getValidMinimumShiftAmount(V, DemandedElts, Depth);
30713084
}
30723085

3073-
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
3074-
SDValue V, const APInt &DemandedElts) const {
3086+
std::optional<uint64_t>
3087+
SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
3088+
unsigned Depth) const {
30753089
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30763090
V.getOpcode() == ISD::SRA) &&
30773091
"Unknown shift node");
3078-
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
3079-
return ValidAmt;
30803092
unsigned BitWidth = V.getScalarValueSizeInBits();
3081-
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
3082-
if (!BV)
3083-
return nullptr;
3084-
const APInt *MaxShAmt = nullptr;
3085-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3086-
if (!DemandedElts[i])
3087-
continue;
3088-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3089-
if (!SA)
3090-
return nullptr;
3091-
// Shifting more than the bitwidth is not valid.
3092-
const APInt &ShAmt = SA->getAPIntValue();
3093-
if (ShAmt.uge(BitWidth))
3094-
return nullptr;
3095-
if (MaxShAmt && MaxShAmt->uge(ShAmt))
3096-
continue;
3097-
MaxShAmt = &ShAmt;
3093+
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3094+
const APInt *MaxShAmt = nullptr;
3095+
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3096+
if (!DemandedElts[i])
3097+
continue;
3098+
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3099+
if (!SA) {
3100+
MaxShAmt = nullptr;
3101+
break;
3102+
}
3103+
// Shifting more than the bitwidth is not valid.
3104+
const APInt &ShAmt = SA->getAPIntValue();
3105+
if (ShAmt.uge(BitWidth))
3106+
return std::nullopt;
3107+
if (MaxShAmt && MaxShAmt->uge(ShAmt))
3108+
continue;
3109+
MaxShAmt = &ShAmt;
3110+
}
3111+
if (MaxShAmt)
3112+
return MaxShAmt->getZExtValue();
30983113
}
3099-
return MaxShAmt;
3114+
KnownBits KnownAmt =
3115+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3116+
if (KnownAmt.getMaxValue().ult(BitWidth))
3117+
return KnownAmt.getMaxValue().getZExtValue();
3118+
return std::nullopt;
31003119
}
31013120

3102-
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
3121+
std::optional<uint64_t>
3122+
SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
31033123
EVT VT = V.getValueType();
31043124
APInt DemandedElts = VT.isFixedLengthVector()
31053125
? APInt::getAllOnes(VT.getVectorNumElements())
31063126
: APInt(1, 1);
3107-
return getValidMaximumShiftAmountConstant(V, DemandedElts);
3127+
return getValidMaximumShiftAmount(V, DemandedElts, Depth);
31083128
}
31093129

31103130
/// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35693589
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
35703590

35713591
// Minimum shift low bits are known zero.
3572-
if (const APInt *ShMinAmt =
3573-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
3574-
Known.Zero.setLowBits(ShMinAmt->getZExtValue());
3592+
if (std::optional<uint64_t> ShMinAmt =
3593+
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
3594+
Known.Zero.setLowBits(*ShMinAmt);
35753595
break;
35763596
}
35773597
case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35813601
Op->getFlags().hasExact());
35823602

35833603
// Minimum shift high bits are known zero.
3584-
if (const APInt *ShMinAmt =
3585-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
3586-
Known.Zero.setHighBits(ShMinAmt->getZExtValue());
3604+
if (std::optional<uint64_t> ShMinAmt =
3605+
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
3606+
Known.Zero.setHighBits(*ShMinAmt);
35873607
break;
35883608
case ISD::SRA:
35893609
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
45874607
case ISD::SRA:
45884608
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
45894609
// SRA X, C -> adds C sign bits.
4590-
if (const APInt *ShAmt =
4591-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
4592-
Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
4610+
if (std::optional<uint64_t> ShAmt =
4611+
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
4612+
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
45934613
return Tmp;
45944614
case ISD::SHL:
4595-
if (const APInt *ShAmt =
4596-
getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
4615+
if (std::optional<uint64_t> ShAmt =
4616+
getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
45974617
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
45984618
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4599-
if (ShAmt->ult(Tmp))
4600-
return Tmp - ShAmt->getZExtValue();
4619+
if (*ShAmt < Tmp)
4620+
return Tmp - *ShAmt;
46014621
}
46024622
break;
46034623
case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
52705290
case ISD::SRL:
52715291
case ISD::SRA:
52725292
// If the max shift amount isn't in range, then the shift can create poison.
5273-
return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
5293+
return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
52745294

52755295
case ISD::SCALAR_TO_VECTOR:
52765296
// Check if we demand any upper (undef) elements.

0 commit comments

Comments
 (0)