Skip to content

Commit 2b1dfd2

Browse files
authored
[DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis (#93182)
The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be partially scalarized or split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE). This patch proposes we generalize these helpers to work with ConstantRange+KnownBits if a scalar/buildvector constant isn't available. Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc. However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element. This addresses feedback on #92096
1 parent 598f37b commit 2b1dfd2

File tree

5 files changed

+163
-144
lines changed

5 files changed

+163
-144
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/CodeGen/SelectionDAGNodes.h"
3333
#include "llvm/CodeGen/ValueTypes.h"
3434
#include "llvm/CodeGenTypes/MachineValueType.h"
35+
#include "llvm/IR/ConstantRange.h"
3536
#include "llvm/IR/DebugLoc.h"
3637
#include "llvm/IR/Metadata.h"
3738
#include "llvm/Support/Allocator.h"
@@ -2159,36 +2160,44 @@ class SelectionDAG {
21592160
/// splatted value it will return SDValue().
21602161
SDValue getSplatValue(SDValue V, bool LegalTypes = false);
21612162

2162-
/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
2163+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2164+
/// element bit-width of the shift node, return the valid constant range.
2165+
std::optional<ConstantRange>
2166+
getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
2167+
unsigned Depth) const;
2168+
2169+
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
21632170
/// that is less than the element bit-width of the shift node, return it.
2164-
const APInt *getValidShiftAmountConstant(SDValue V,
2165-
const APInt &DemandedElts) const;
2171+
std::optional<uint64_t> getValidShiftAmount(SDValue V,
2172+
const APInt &DemandedElts,
2173+
unsigned Depth = 0) const;
21662174

2167-
/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
2175+
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
21682176
/// that is less than the element bit-width of the shift node, return it.
2169-
const APInt *getValidShiftAmountConstant(SDValue V) const;
2170-
2171-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2172-
/// than the element bit-width of the shift node, return the minimum value.
2173-
const APInt *
2174-
getValidMinimumShiftAmountConstant(SDValue V,
2175-
const APInt &DemandedElts) const;
2176-
2177-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2178-
/// than the element bit-width of the shift node, return the minimum value.
2179-
const APInt *
2180-
getValidMinimumShiftAmountConstant(SDValue V) const;
2181-
2182-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2183-
/// than the element bit-width of the shift node, return the maximum value.
2184-
const APInt *
2185-
getValidMaximumShiftAmountConstant(SDValue V,
2186-
const APInt &DemandedElts) const;
2187-
2188-
/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
2189-
/// than the element bit-width of the shift node, return the maximum value.
2190-
const APInt *
2191-
getValidMaximumShiftAmountConstant(SDValue V) const;
2177+
std::optional<uint64_t> getValidShiftAmount(SDValue V,
2178+
unsigned Depth = 0) const;
2179+
2180+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2181+
/// element bit-width of the shift node, return the minimum possible value.
2182+
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
2183+
const APInt &DemandedElts,
2184+
unsigned Depth = 0) const;
2185+
2186+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2187+
/// element bit-width of the shift node, return the minimum possible value.
2188+
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
2189+
unsigned Depth = 0) const;
2190+
2191+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2192+
/// element bit-width of the shift node, return the maximum possible value.
2193+
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
2194+
const APInt &DemandedElts,
2195+
unsigned Depth = 0) const;
2196+
2197+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2198+
/// element bit-width of the shift node, return the maximum possible value.
2199+
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
2200+
unsigned Depth = 0) const;
21922201

21932202
/// Match a binop + shuffle pyramid that represents a horizontal reduction
21942203
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 92 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
#include "llvm/CodeGen/ValueTypes.h"
4949
#include "llvm/CodeGenTypes/MachineValueType.h"
5050
#include "llvm/IR/Constant.h"
51-
#include "llvm/IR/ConstantRange.h"
5251
#include "llvm/IR/Constants.h"
5352
#include "llvm/IR/DataLayout.h"
5453
#include "llvm/IR/DebugInfoMetadata.h"
@@ -3009,102 +3008,117 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
30093008
return SDValue();
30103009
}
30113010

3012-
const APInt *
3013-
SelectionDAG::getValidShiftAmountConstant(SDValue V,
3014-
const APInt &DemandedElts) const {
3011+
std::optional<ConstantRange>
3012+
SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
3013+
unsigned Depth) const {
30153014
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30163015
V.getOpcode() == ISD::SRA) &&
30173016
"Unknown shift node");
3017+
// Shifting more than the bitwidth is not valid.
30183018
unsigned BitWidth = V.getScalarValueSizeInBits();
3019-
if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
3020-
// Shifting more than the bitwidth is not valid.
3021-
const APInt &ShAmt = SA->getAPIntValue();
3022-
if (ShAmt.ult(BitWidth))
3023-
return &ShAmt;
3019+
3020+
if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
3021+
const APInt &ShAmt = Cst->getAPIntValue();
3022+
if (ShAmt.uge(BitWidth))
3023+
return std::nullopt;
3024+
return ConstantRange(ShAmt);
30243025
}
3025-
return nullptr;
3026+
3027+
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3028+
const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
3029+
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3030+
if (!DemandedElts[i])
3031+
continue;
3032+
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3033+
if (!SA) {
3034+
MinAmt = MaxAmt = nullptr;
3035+
break;
3036+
}
3037+
const APInt &ShAmt = SA->getAPIntValue();
3038+
if (ShAmt.uge(BitWidth))
3039+
return std::nullopt;
3040+
if (!MinAmt || MinAmt->ugt(ShAmt))
3041+
MinAmt = &ShAmt;
3042+
if (!MaxAmt || MaxAmt->ult(ShAmt))
3043+
MaxAmt = &ShAmt;
3044+
}
3045+
assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
3046+
"Failed to find matching min/max shift amounts");
3047+
if (MinAmt && MaxAmt)
3048+
return ConstantRange(*MinAmt, *MaxAmt + 1);
3049+
}
3050+
3051+
// Use computeKnownBits to find a hidden constant/knownbits (usually type
3052+
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3053+
KnownBits KnownAmt = computeKnownBits(V.getOperand(1), DemandedElts, Depth);
3054+
if (KnownAmt.getMaxValue().ult(BitWidth))
3055+
return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);
3056+
3057+
return std::nullopt;
30263058
}
30273059

3028-
const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
3060+
std::optional<uint64_t>
3061+
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
3062+
unsigned Depth) const {
3063+
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
3064+
V.getOpcode() == ISD::SRA) &&
3065+
"Unknown shift node");
3066+
if (std::optional<ConstantRange> AmtRange =
3067+
getValidShiftAmountRange(V, DemandedElts, Depth))
3068+
if (const APInt *ShAmt = AmtRange->getSingleElement())
3069+
return ShAmt->getZExtValue();
3070+
return std::nullopt;
3071+
}
3072+
3073+
std::optional<uint64_t>
3074+
SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
30293075
EVT VT = V.getValueType();
30303076
APInt DemandedElts = VT.isFixedLengthVector()
30313077
? APInt::getAllOnes(VT.getVectorNumElements())
30323078
: APInt(1, 1);
3033-
return getValidShiftAmountConstant(V, DemandedElts);
3079+
return getValidShiftAmount(V, DemandedElts, Depth);
30343080
}
30353081

3036-
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
3037-
SDValue V, const APInt &DemandedElts) const {
3082+
std::optional<uint64_t>
3083+
SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
3084+
unsigned Depth) const {
30383085
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30393086
V.getOpcode() == ISD::SRA) &&
30403087
"Unknown shift node");
3041-
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
3042-
return ValidAmt;
3043-
unsigned BitWidth = V.getScalarValueSizeInBits();
3044-
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
3045-
if (!BV)
3046-
return nullptr;
3047-
const APInt *MinShAmt = nullptr;
3048-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3049-
if (!DemandedElts[i])
3050-
continue;
3051-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3052-
if (!SA)
3053-
return nullptr;
3054-
// Shifting more than the bitwidth is not valid.
3055-
const APInt &ShAmt = SA->getAPIntValue();
3056-
if (ShAmt.uge(BitWidth))
3057-
return nullptr;
3058-
if (MinShAmt && MinShAmt->ule(ShAmt))
3059-
continue;
3060-
MinShAmt = &ShAmt;
3061-
}
3062-
return MinShAmt;
3088+
if (std::optional<ConstantRange> AmtRange =
3089+
getValidShiftAmountRange(V, DemandedElts, Depth))
3090+
return AmtRange->getUnsignedMin().getZExtValue();
3091+
return std::nullopt;
30633092
}
30643093

3065-
const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
3094+
std::optional<uint64_t>
3095+
SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
30663096
EVT VT = V.getValueType();
30673097
APInt DemandedElts = VT.isFixedLengthVector()
30683098
? APInt::getAllOnes(VT.getVectorNumElements())
30693099
: APInt(1, 1);
3070-
return getValidMinimumShiftAmountConstant(V, DemandedElts);
3100+
return getValidMinimumShiftAmount(V, DemandedElts, Depth);
30713101
}
30723102

3073-
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
3074-
SDValue V, const APInt &DemandedElts) const {
3103+
std::optional<uint64_t>
3104+
SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
3105+
unsigned Depth) const {
30753106
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30763107
V.getOpcode() == ISD::SRA) &&
30773108
"Unknown shift node");
3078-
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
3079-
return ValidAmt;
3080-
unsigned BitWidth = V.getScalarValueSizeInBits();
3081-
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
3082-
if (!BV)
3083-
return nullptr;
3084-
const APInt *MaxShAmt = nullptr;
3085-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3086-
if (!DemandedElts[i])
3087-
continue;
3088-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3089-
if (!SA)
3090-
return nullptr;
3091-
// Shifting more than the bitwidth is not valid.
3092-
const APInt &ShAmt = SA->getAPIntValue();
3093-
if (ShAmt.uge(BitWidth))
3094-
return nullptr;
3095-
if (MaxShAmt && MaxShAmt->uge(ShAmt))
3096-
continue;
3097-
MaxShAmt = &ShAmt;
3098-
}
3099-
return MaxShAmt;
3109+
if (std::optional<ConstantRange> AmtRange =
3110+
getValidShiftAmountRange(V, DemandedElts, Depth))
3111+
return AmtRange->getUnsignedMax().getZExtValue();
3112+
return std::nullopt;
31003113
}
31013114

3102-
const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
3115+
std::optional<uint64_t>
3116+
SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
31033117
EVT VT = V.getValueType();
31043118
APInt DemandedElts = VT.isFixedLengthVector()
31053119
? APInt::getAllOnes(VT.getVectorNumElements())
31063120
: APInt(1, 1);
3107-
return getValidMaximumShiftAmountConstant(V, DemandedElts);
3121+
return getValidMaximumShiftAmount(V, DemandedElts, Depth);
31083122
}
31093123

31103124
/// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3583,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35693583
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
35703584

35713585
// Minimum shift low bits are known zero.
3572-
if (const APInt *ShMinAmt =
3573-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
3574-
Known.Zero.setLowBits(ShMinAmt->getZExtValue());
3586+
if (std::optional<uint64_t> ShMinAmt =
3587+
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
3588+
Known.Zero.setLowBits(*ShMinAmt);
35753589
break;
35763590
}
35773591
case ISD::SRL:
@@ -3581,9 +3595,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35813595
Op->getFlags().hasExact());
35823596

35833597
// Minimum shift high bits are known zero.
3584-
if (const APInt *ShMinAmt =
3585-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
3586-
Known.Zero.setHighBits(ShMinAmt->getZExtValue());
3598+
if (std::optional<uint64_t> ShMinAmt =
3599+
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
3600+
Known.Zero.setHighBits(*ShMinAmt);
35873601
break;
35883602
case ISD::SRA:
35893603
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4601,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
45874601
case ISD::SRA:
45884602
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
45894603
// SRA X, C -> adds C sign bits.
4590-
if (const APInt *ShAmt =
4591-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
4592-
Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
4604+
if (std::optional<uint64_t> ShAmt =
4605+
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
4606+
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
45934607
return Tmp;
45944608
case ISD::SHL:
4595-
if (const APInt *ShAmt =
4596-
getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
4609+
if (std::optional<uint64_t> ShAmt =
4610+
getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
45974611
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
45984612
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4599-
if (ShAmt->ult(Tmp))
4600-
return Tmp - ShAmt->getZExtValue();
4613+
if (*ShAmt < Tmp)
4614+
return Tmp - *ShAmt;
46014615
}
46024616
break;
46034617
case ISD::AND:
@@ -5270,7 +5284,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
52705284
case ISD::SRL:
52715285
case ISD::SRA:
52725286
// If the max shift amount isn't in range, then the shift can create poison.
5273-
return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
5287+
return !getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1);
52745288

52755289
case ISD::SCALAR_TO_VECTOR:
52765290
// Check if we demand any upper (undef) elements.

0 commit comments

Comments
 (0)