[DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis #93182
@llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-backend-powerpc @llvm/pr-subscribers-llvm-selectiondag

Author: Simon Pilgrim (RKSimon)

Changes

The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as 32-bit x86, Thumb2 MVE).

This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available.

Most restrictions are the same: the helpers fail if any shift amount is out of bounds, getValidShiftAmount must still resolve to a single uniform constant, etc. However, getValidMinimumShiftAmount/getValidMaximumShiftAmount can now return bound values that aren't values in the actual data, as they are based on the common KnownBits of every vector element.

This addresses feedback on #92096.

Full diff: https://github.com/llvm/llvm-project/pull/93182.diff

5 Files Affected:
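To make the point about bounds that are not values in the actual data concrete, here is a minimal standalone sketch. It does not use LLVM: `ToyKnownBits`, `commonKnownBits` and the `toy*ShiftAmount` functions are hypothetical stand-ins that only mirror the shape of the KnownBits fallback in the patch, assuming shift amounts that fit in 64 bits.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for llvm::KnownBits (not the LLVM class):
// a bit set in Zero is known to be 0, a bit set in One is known to be 1.
struct ToyKnownBits {
  uint64_t Zero;
  uint64_t One;
  uint64_t minValue() const { return One; }   // every unknown bit taken as 0
  uint64_t maxValue() const { return ~Zero; } // every unknown bit taken as 1
  bool isConstant() const { return (Zero | One) == ~uint64_t(0); }
};

// Keep only the bits that agree across every shift amount, which is all a
// common-bits analysis over the vector elements can prove.
ToyKnownBits commonKnownBits(const std::vector<uint64_t> &Amts) {
  assert(!Amts.empty() && "need at least one shift amount");
  ToyKnownBits K{~uint64_t(0), ~uint64_t(0)};
  for (uint64_t A : Amts) {
    K.Zero &= ~A;
    K.One &= A;
  }
  return K;
}

// Uniform amount: only available when every bit is known and in range.
std::optional<uint64_t> toyUniformShiftAmount(const ToyKnownBits &K,
                                              unsigned BitWidth) {
  if (K.isConstant() && K.One < BitWidth)
    return K.One;
  return std::nullopt;
}

// Lower bound: valid only if even the largest possible amount is in range.
std::optional<uint64_t> toyMinShiftAmount(const ToyKnownBits &K,
                                          unsigned BitWidth) {
  if (K.maxValue() < BitWidth)
    return K.minValue();
  return std::nullopt;
}

// Upper bound: same validity check, but report the largest possible amount.
std::optional<uint64_t> toyMaxShiftAmount(const ToyKnownBits &K,
                                          unsigned BitWidth) {
  if (K.maxValue() < BitWidth)
    return K.maxValue();
  return std::nullopt;
}
```

With per-element amounts {1, 2} on a 32-bit shift, only the upper 62 bits are commonly known (as zero), so the sketch reports a minimum of 0 and a maximum of 3. Neither value appears in the data, but both are safe bounds, and no uniform amount is found.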
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 96a6270690468..95afbeb5dd6ec 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2159,36 +2159,38 @@ class SelectionDAG {
/// splatted value it will return SDValue().
SDValue getSplatValue(SDValue V, bool LegalTypes = false);
- /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+ /// If a SHL/SRA/SRL node \p V has an uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
- const APInt *getValidShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const;
+ std::optional<uint64_t> getValidShiftAmount(SDValue V,
+ const APInt &DemandedElts,
+ unsigned Depth = 0) const;
- /// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
+ /// If a SHL/SRA/SRL node \p V has an uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
- const APInt *getValidShiftAmountConstant(SDValue V) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the minimum value.
- const APInt *
- getValidMinimumShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the minimum value.
- const APInt *
- getValidMinimumShiftAmountConstant(SDValue V) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the maximum value.
- const APInt *
- getValidMaximumShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const;
-
- /// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
- /// than the element bit-width of the shift node, return the maximum value.
- const APInt *
- getValidMaximumShiftAmountConstant(SDValue V) const;
+ std::optional<uint64_t> getValidShiftAmount(SDValue V,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the minimum possible value.
+ std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+ const APInt &DemandedElts,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the minimum possible value.
+ std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the maximum possible value.
+ std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+ const APInt &DemandedElts,
+ unsigned Depth = 0) const;
+
+ /// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
+ /// element bit-width of the shift node, return the maximum possible value.
+ std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
+ unsigned Depth = 0) const;
/// Match a binop + shuffle pyramid that represents a horizontal reduction
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..b71b496c0aa84 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
return SDValue();
}
-const APInt *
-SelectionDAG::getValidShiftAmountConstant(SDValue V,
- const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
// Shifting more than the bitwidth is not valid.
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.ult(BitWidth))
- return &ShAmt;
+ return ShAmt.getZExtValue();
+ } else {
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
+ return KnownAmt.getConstant().getZExtValue();
}
- return nullptr;
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidShiftAmountConstant(V, DemandedElts);
+ return getValidShiftAmount(V, DemandedElts, Depth);
}
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
- SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
- if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
- return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
- if (!BV)
- return nullptr;
- const APInt *MinShAmt = nullptr;
- for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
- if (!DemandedElts[i])
- continue;
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
- if (!SA)
- return nullptr;
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.uge(BitWidth))
- return nullptr;
- if (MinShAmt && MinShAmt->ule(ShAmt))
- continue;
- MinShAmt = &ShAmt;
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MinShAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MinShAmt = nullptr;
+ break;
+ }
+ // Shifting more than the bitwidth is not valid.
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (MinShAmt && MinShAmt->ule(ShAmt))
+ continue;
+ MinShAmt = &ShAmt;
+ }
+ if (MinShAmt)
+ return MinShAmt->getZExtValue();
}
- return MinShAmt;
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.getMaxValue().ult(BitWidth))
+ return KnownAmt.getMinValue().getZExtValue();
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidMinimumShiftAmountConstant(V, DemandedElts);
+ return getValidMinimumShiftAmount(V, DemandedElts, Depth);
}
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
- SDValue V, const APInt &DemandedElts) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
+ unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
- if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
- return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
- if (!BV)
- return nullptr;
- const APInt *MaxShAmt = nullptr;
- for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
- if (!DemandedElts[i])
- continue;
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
- if (!SA)
- return nullptr;
- // Shifting more than the bitwidth is not valid.
- const APInt &ShAmt = SA->getAPIntValue();
- if (ShAmt.uge(BitWidth))
- return nullptr;
- if (MaxShAmt && MaxShAmt->uge(ShAmt))
- continue;
- MaxShAmt = &ShAmt;
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
+ const APInt *MaxShAmt = nullptr;
+ for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
+ if (!SA) {
+ MaxShAmt = nullptr;
+ break;
+ }
+ // Shifting more than the bitwidth is not valid.
+ const APInt &ShAmt = SA->getAPIntValue();
+ if (ShAmt.uge(BitWidth))
+ return std::nullopt;
+ if (MaxShAmt && MaxShAmt->uge(ShAmt))
+ continue;
+ MaxShAmt = &ShAmt;
+ }
+ if (MaxShAmt)
+ return MaxShAmt->getZExtValue();
}
- return MaxShAmt;
+ KnownBits KnownAmt =
+ computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
+ if (KnownAmt.getMaxValue().ult(BitWidth))
+ return KnownAmt.getMaxValue().getZExtValue();
+ return std::nullopt;
}
-const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
+std::optional<uint64_t>
+SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
- return getValidMaximumShiftAmountConstant(V, DemandedElts);
+ return getValidMaximumShiftAmount(V, DemandedElts, Depth);
}
/// Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
// Minimum shift low bits are known zero.
- if (const APInt *ShMinAmt =
- getValidMinimumShiftAmountConstant(Op, DemandedElts))
- Known.Zero.setLowBits(ShMinAmt->getZExtValue());
+ if (std::optional<uint64_t> ShMinAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+ Known.Zero.setLowBits(*ShMinAmt);
break;
}
case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Op->getFlags().hasExact());
// Minimum shift high bits are known zero.
- if (const APInt *ShMinAmt =
- getValidMinimumShiftAmountConstant(Op, DemandedElts))
- Known.Zero.setHighBits(ShMinAmt->getZExtValue());
+ if (std::optional<uint64_t> ShMinAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+ Known.Zero.setHighBits(*ShMinAmt);
break;
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
// SRA X, C -> adds C sign bits.
- if (const APInt *ShAmt =
- getValidMinimumShiftAmountConstant(Op, DemandedElts))
- Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+ if (std::optional<uint64_t> ShAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth))
+ Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
return Tmp;
case ISD::SHL:
- if (const APInt *ShAmt =
- getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+ if (std::optional<uint64_t> ShAmt =
+ getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
- if (ShAmt->ult(Tmp))
- return Tmp - ShAmt->getZExtValue();
+ if (*ShAmt < Tmp)
+ return Tmp - *ShAmt;
}
break;
case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
case ISD::SRA:
// If the max shift amount isn't in range, then the shift can create poison.
- return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
+ return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
case ISD::SCALAR_TO_VECTOR:
// Check if we demand any upper (undef) elements.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 37c72339fe295..dfcd5439b8a9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -796,10 +796,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SHL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (const APInt *MaxSA =
- DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
+ if (std::optional<uint64_t> MaxSA =
+ DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
SDValue Op0 = Op.getOperand(0);
- unsigned ShAmt = MaxSA->getZExtValue();
+ unsigned ShAmt = *MaxSA;
unsigned NumSignBits =
DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -1789,9 +1789,9 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
InnerOp.hasOneUse()) {
- if (const APInt *SA2 =
- TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
- unsigned InnerShAmt = SA2->getZExtValue();
+ if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+ InnerOp, DemandedElts, Depth + 1)) {
+ unsigned InnerShAmt = *SA2;
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
DemandedBits.getActiveBits() <=
(InnerBits - InnerShAmt + ShAmt) &&
@@ -1918,9 +1918,9 @@ bool TargetLowering::SimplifyDemandedBits(
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (const APInt *MaxSA =
- TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
- unsigned ShAmt = MaxSA->getZExtValue();
+ if (std::optional<uint64_t> MaxSA =
+ TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
+ unsigned ShAmt = *MaxSA;
unsigned NumSignBits =
TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
@@ -2598,11 +2598,11 @@ bool TargetLowering::SimplifyDemandedBits(
break;
if (Src.getNode()->hasOneUse()) {
- const APInt *ShAmtC =
- TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
- if (!ShAmtC || ShAmtC->uge(BitWidth))
+ std::optional<uint64_t> ShAmtC =
+ TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
+ if (!ShAmtC || *ShAmtC >= BitWidth)
break;
- uint64_t ShVal = ShAmtC->getZExtValue();
+ uint64_t ShVal = *ShAmtC;
APInt HighBits =
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 215cbc308e43d..fd99f0e345d14 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20490,7 +20490,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
// the truncation then we can use PACKSS by converting the srl to a sra.
// SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
if (In.getOpcode() == ISD::SRL && In->hasOneUse())
- if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(In)) {
+ if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
if (*ShAmt == MinSignBits) {
PackOpcode = X86ISD::PACKSS;
return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
; CHECK-NEXT: mflr r0
; CHECK-NEXT: std r30, -16(r1) # 8-byte Folded Spill
; CHECK-NEXT: stdu r1, -48(r1)
-; CHECK-NEXT: std r0, 64(r1)
; CHECK-NEXT: mr r30, r3
-; CHECK-NEXT: ld r3, 8(r3)
+; CHECK-NEXT: std r0, 64(r1)
+; CHECK-NEXT: lwz r3, 8(r3)
; CHECK-NEXT: lwz r4, 36(r30)
-; CHECK-NEXT: rldicl r3, r3, 60, 4
+; CHECK-NEXT: rlwinm r3, r3, 27, 0, 0
; CHECK-NEXT: clrlwi r4, r4, 31
-; CHECK-NEXT: slwi r3, r3, 31
; CHECK-NEXT: rlwimi r4, r3, 0, 0, 0
; CHECK-NEXT: bl _ZN1llsE1d
; CHECK-NEXT: nop
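The ComputeNumSignBits hunk in SelectionDAG.cpp uses the minimum bound for SRA and the maximum bound for SHL. A rough standalone sketch of why each direction is the conservative one follows. The function names are hypothetical, and returning 1 when no bound is available is a simplification of the real code, which falls through to further analysis.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// SRA replicates the sign bit, so shifting by at least MinShAmt adds at least
// MinShAmt sign bits, capped at the element width. Using the lower bound can
// only under-count, which is safe.
unsigned toySignBitsAfterSra(unsigned SrcSignBits,
                             std::optional<uint64_t> MinShAmt,
                             unsigned VTBits) {
  assert(SrcSignBits >= 1 && SrcSignBits <= VTBits && "invalid sign bit count");
  if (MinShAmt)
    return static_cast<unsigned>(
        std::min<uint64_t>(SrcSignBits + *MinShAmt, VTBits));
  return SrcSignBits;
}

// SHL shifts sign bits out of the top, so the worst case is the largest
// possible amount. If that could remove every sign bit, or no bound is known,
// fall back to the trivial answer of one sign bit (a simplification).
unsigned toySignBitsAfterShl(unsigned SrcSignBits,
                             std::optional<uint64_t> MaxShAmt) {
  if (MaxShAmt && *MaxShAmt < SrcSignBits)
    return static_cast<unsigned>(SrcSignBits - *MaxShAmt);
  return 1;
}
```

Because both helpers now return `std::optional<uint64_t>`, a bound derived from KnownBits flows through these formulas exactly as a constant bound did before.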
diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index 4d2c81c35b7fe..dc3e129922971 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -12,13 +12,12 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
; CHECK-NEXT: mflr r0
; CHECK-NEXT: std r30, -16(r1) # 8-byte Folded Spill
; CHECK-NEXT: stdu r1, -48(r1)
-; CHECK-NEXT: std r0, 64(r1)
; CHECK-NEXT: mr r30, r3
-; CHECK-NEXT: ld r3, 8(r3)
+; CHECK-NEXT: std r0, 64(r1)
+; CHECK-NEXT: lwz r3, 8(r3)
; CHECK-NEXT: lwz r4, 36(r30)
-; CHECK-NEXT: rldicl r3, r3, 60, 4
+; CHECK-NEXT: rlwinm r3, r3, 27, 0, 0
; CHECK-NEXT: clrlwi r4, r4, 31
-; CHECK-NEXT: slwi r3, r3, 31
; CHECK-NEXT: rlwimi r4, r3, 0, 0, 0
; CHECK-NEXT: bl _ZN1llsE1d
; CHECK-NEXT: nop
} else {
KnownBits KnownAmt =
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
IMO if it's constant only, Depth should start at MaxDepth - 2; otherwise I think there is a lot of wasted compute.
The problem we're hitting is that in many cases it's not just a bitcast of a constant; it's often more complex than that. When trying to use value tracking during legalization we only get one chance, so we can't wait for combines to simplify things.
Okay, fair enough (maybe mention in a comment? This seems a bit non-intuitive).
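As a rough illustration of the fallback discussed in this thread, the sketch below accepts a shift amount only when every bit of it is known and the resulting value is in range. ToyKnownBits and toyGetValidShiftAmount are hypothetical stand-ins written for this example, not LLVM's KnownBits or SelectionDAG API; only the Zero/One mask idea and the isConstant-then-compare check are taken from the snippet under review.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for llvm::KnownBits, limited to 64-bit values.
// A bit set in Zero is known to be 0; a bit set in One is known to be 1.
struct ToyKnownBits {
  uint64_t Zero = 0;
  uint64_t One = 0;
  unsigned Width = 64; // number of meaningful bits (1..64)

  uint64_t mask() const {
    return Width == 64 ? ~0ULL : ((1ULL << Width) - 1);
  }
  // Every meaningful bit is known, so the value is a single constant.
  bool isConstant() const { return ((Zero | One) & mask()) == mask(); }
  uint64_t getConstant() const { return One & mask(); }
};

// Sketch of the uniform-amount check: succeed only for a fully known,
// in-range amount (the isConstant() && ult(BitWidth) test in the snippet).
std::optional<uint64_t> toyGetValidShiftAmount(const ToyKnownBits &Amt,
                                               unsigned BitWidth) {
  if (Amt.isConstant() && Amt.getConstant() < BitWidth)
    return Amt.getConstant();
  return std::nullopt;
}
```

In this toy model it does not matter how the known bits were derived (a plain constant, or a constant hidden behind bitcasts and build_vectors); only the final mask pair is inspected.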
b7c6586 to 9ced5e9
9ced5e9 to bbe3c19
any more comments?
aaee413 to 57b8519
ping - anything else?
// Use computeKnownBits to find a hidden constant (usually type legalized).
// e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
KnownBits KnownAmt =
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
I don't think Depth + 1 makes sense here. This isn't a recursive function. If it's used by a recursive function, the caller should be setting Depth + 1.
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
if (KnownAmt.getMaxValue().ult(BitWidth))
return KnownAmt.getMaxValue().getZExtValue();
return std::nullopt;
I think there should be a single helper for all of the getValid*ShiftAmount functions.
I think it could just be the getValidShiftAmount impl, but returning a KnownBits (for the constant case we can just use KnownBits::makeConstant()).
Then getValidShiftAmount returns if the result is constant and less than bitwidth, getValidMinimumShiftAmount returns the getMinValue(), and vice versa for maximum. I think that will save a lot of duplicated code.
Let me take a look, I'm wondering if ConstantRange will help us here. KnownBits alone tends to make the min/max values 'fuzzy'.
It's a bit of a mixed bag. The KnownBits implementation is more complete, but it has to throw away range information that isn't tied to particular bits. ConstantRange has a less complete implementation, but it obviously represents what we want here better.
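The two points in this thread (one shared helper with three thin wrappers, and KnownBits giving 'fuzzy' bounds) can be sketched together. Everything below is a hypothetical toy model written for illustration: ToyKnownBits, commonKnownBits and the toy*ShiftAmount wrappers are not LLVM APIs, and the real patch may structure this differently.

```cpp
#include <cstdint>
#include <initializer_list>
#include <optional>

// Hypothetical stand-in for llvm::KnownBits (64-bit only).
struct ToyKnownBits {
  uint64_t Zero; // bits known to be 0
  uint64_t One;  // bits known to be 1
  uint64_t getMinValue() const { return One; }   // unknown bits taken as 0
  uint64_t getMaxValue() const { return ~Zero; } // unknown bits taken as 1
  bool isConstant() const { return (Zero | One) == ~0ULL; }
};

// Bits shared by every element: the "common KnownBits" of the amounts.
// Assumes Vals is non-empty.
ToyKnownBits commonKnownBits(std::initializer_list<uint64_t> Vals) {
  ToyKnownBits K{~0ULL, ~0ULL};
  for (uint64_t V : Vals) {
    K.Zero &= ~V;
    K.One &= V;
  }
  return K;
}

// Three thin wrappers over one KnownBits result. All of them fail if any
// possible amount is out of bounds.
std::optional<uint64_t> toyShiftAmount(const ToyKnownBits &K, unsigned BW) {
  if (K.isConstant() && K.getMinValue() < BW)
    return K.getMinValue();
  return std::nullopt;
}
std::optional<uint64_t> toyMinShiftAmount(const ToyKnownBits &K, unsigned BW) {
  if (K.getMaxValue() < BW)
    return K.getMinValue();
  return std::nullopt;
}
std::optional<uint64_t> toyMaxShiftAmount(const ToyKnownBits &K, unsigned BW) {
  if (K.getMaxValue() < BW)
    return K.getMaxValue();
  return std::nullopt;
}
```

For per-element amounts {4, 5, 6}, only bit 2 is common to all elements, so this model reports a minimum of 4 and a maximum of 7, even though no element is 7. A range representation would keep the tighter bound of 6, which is the trade-off described above.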
57b8519 to 3879271
// Use computeKnownBits to find a hidden constant/knownbits (usually type
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
KnownBits KnownAmt =
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
Still think this should be Depth and the callers should be in charge of managing Depth + 1 (the callers to getValidShiftAmount/getValidMinimumShiftAmount/getValidMaximumShiftAmount that is).
I don't really mind either way - we do pass in the shift node and not the shift amount specifically so I can see both sides of it. This PR was just to pull out a minor tweak of #92096 and has instead turned into a monster :(
It's my preference, but not a requirement.
I guess my feeling is that for most of our recursive functions in DAGCombiner we put the +1 at the initial call site, so it might be confusing otherwise. Also, this prohibits a 0 depth use.
OK - the depth = 0 case could be an issue in the matchTruncateWithPACK use in X86ISelLowering.cpp - this is currently the only call from a non-recursive function, but there will probably be more in the future.
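A small sketch of why the location of the +1 matters for a depth = 0 caller. This is a toy model under stated assumptions: ToyMaxDepth is an arbitrary cutoff picked for the example, and toyAnalyze, helperAddsOne and helperForwards are hypothetical names, not SelectionDAG functions.

```cpp
#include <optional>

constexpr unsigned ToyMaxDepth = 6; // assumed recursion cutoff for this sketch

// Toy recursive analysis: the constant sits behind `Wrappers` wrapper nodes
// (think bitcasts). Each wrapper costs one level; give up at the cutoff.
std::optional<unsigned> toyAnalyze(unsigned Wrappers, unsigned Value,
                                   unsigned Depth) {
  if (Depth >= ToyMaxDepth)
    return std::nullopt;
  if (Wrappers == 0)
    return Value;
  return toyAnalyze(Wrappers - 1, Value, Depth + 1);
}

// Variant A: the helper bumps the depth itself, so even a non-recursive
// caller passing Depth = 0 starts the analysis at level 1.
std::optional<unsigned> helperAddsOne(unsigned Wrappers, unsigned Value,
                                      unsigned Depth) {
  return toyAnalyze(Wrappers, Value, Depth + 1);
}

// Variant B: the helper forwards Depth unchanged; recursive callers are
// expected to pass Depth + 1 themselves.
std::optional<unsigned> helperForwards(unsigned Wrappers, unsigned Value,
                                       unsigned Depth) {
  return toyAnalyze(Wrappers, Value, Depth);
}
```

With five wrappers and a top-level call at depth 0, helperForwards still reaches the constant while helperAddsOne hits the cutoff one level early. That lost level is the cost a non-recursive caller pays when the helper owns the increment.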
LGTM (with preference for changing Wait a day or so before pushing so others have a chance to take a look.
3879271 to ee0341d
if (!ShAmtC || ShAmtC->uge(BitWidth))
std::optional<uint64_t> ShAmtC =
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
if (!ShAmtC || *ShAmtC >= BitWidth)
Why Depth + 2? (here and above)
These are 'inner' shifts of shifts. The originals should have used Depth+1 but that was missed.
Thanks everyone, unless anyone has any more concerns I'm intending to commit this over the weekend.
…tAmount helpers to support KnownBits analysis The getValidShiftAmountConstant/getValidMinimumShiftAmountConstant/getValidMaximumShiftAmountConstant helpers only worked with constant shift amounts, which could be problematic after type legalization (e.g. v2i64 might be split into v4i32 on some targets such as Thumb2 MVE). This patch proposes we generalize these helpers to work with KnownBits if a scalar/buildvector constant isn't available. Most restrictions are the same - the helper fails if any shift amount is out of bounds, getValidShiftConstant must be a specific constant uniform etc. However, getValidMinimumShiftAmount/getValidMaximumShiftAmount now can return bounds values that aren't values in the actual data, as they are based off the common KnownBits of every vector element. This addresses feedback on llvm#92096
…ift amount values.
…rands to the caller
I forgot to account for ConstantRange expecting the upper bound to be non-inclusive
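The last commit message refers to the upper bound of a range being non-inclusive. As a hedged illustration of that off-by-one, here is a toy half-open range; ToyRange and makeRangeInclusive are hypothetical names for this sketch and do not reproduce llvm::ConstantRange or the actual fix.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical half-open range [Lower, Upper): non-empty and non-wrapping.
struct ToyRange {
  uint64_t Lower; // inclusive
  uint64_t Upper; // exclusive
  uint64_t getUnsignedMin() const { return Lower; }
  uint64_t getUnsignedMax() const { return Upper - 1; }
};

// Build a range from inclusive bounds: the exclusive upper bound is Max + 1.
// Passing Max directly as Upper would silently drop the largest value.
ToyRange makeRangeInclusive(uint64_t Min, uint64_t Max) {
  assert(Min <= Max && Max != UINT64_MAX && "toy model: no wrapping");
  return ToyRange{Min, Max + 1};
}
```

For inclusive shift-amount bounds of 3 and 5, the stored upper bound is 6 and getUnsignedMax() reports 5 again.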
55fb0ca to 5af5cd1
As usual when touching SimplifyDemandedBits et al this has managed to expose a bug somewhere else that only shows up in multi stage build bots - currently looking at this.
…en not poison Since #93182 we can now call computeKnownBits inside getValidMaximumShiftAmount to determine the bounds of the shift amount, ensuring that it wasn't poison. This means that if we did freeze the shift amount, isGuaranteedNotToBeUndefOrPoison would then fail, as we can't call computeKnownBits through FREEZE for potentially poison values. I'm still reducing a decent test case but wanted to get the buildbot fix in ASAP.