Skip to content

Commit 5cbd598

Browse files
committed
[DAG] getValidShiftAmountRange - move responsibility of the Depth operands to the caller
1 parent 922fbe8 commit 5cbd598

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3050,8 +3050,7 @@ SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
30503050

30513051
// Use computeKnownBits to find a hidden constant/knownbits (usually type
30523052
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3053-
KnownBits KnownAmt =
3054-
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3053+
KnownBits KnownAmt = computeKnownBits(V.getOperand(1), DemandedElts, Depth);
30553054
if (KnownAmt.getMaxValue().ult(BitWidth))
30563055
return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);
30573056

@@ -3585,7 +3584,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35853584

35863585
// Minimum shift low bits are known zero.
35873586
if (std::optional<uint64_t> ShMinAmt =
3588-
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
3587+
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
35893588
Known.Zero.setLowBits(*ShMinAmt);
35903589
break;
35913590
}
@@ -3597,7 +3596,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35973596

35983597
// Minimum shift high bits are known zero.
35993598
if (std::optional<uint64_t> ShMinAmt =
3600-
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
3599+
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
36013600
Known.Zero.setHighBits(*ShMinAmt);
36023601
break;
36033602
case ISD::SRA:
@@ -4603,12 +4602,12 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
46034602
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
46044603
// SRA X, C -> adds C sign bits.
46054604
if (std::optional<uint64_t> ShAmt =
4606-
getValidMinimumShiftAmount(Op, DemandedElts, Depth))
4605+
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
46074606
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
46084607
return Tmp;
46094608
case ISD::SHL:
46104609
if (std::optional<uint64_t> ShAmt =
4611-
getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
4610+
getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
46124611
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
46134612
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
46144613
if (*ShAmt < Tmp)
@@ -5285,7 +5284,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
52855284
case ISD::SRL:
52865285
case ISD::SRA:
52875286
// If the max shift amount isn't in range, then the shift can create poison.
5288-
return !getValidMaximumShiftAmount(Op, DemandedElts, Depth);
5287+
return !getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1);
52895288

52905289
case ISD::SCALAR_TO_VECTOR:
52915290
// Check if we demand any upper (undef) elements.

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
797797
// If we are only demanding sign bits then we can use the shift source
798798
// directly.
799799
if (std::optional<uint64_t> MaxSA =
800-
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
800+
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
801801
SDValue Op0 = Op.getOperand(0);
802802
unsigned ShAmt = *MaxSA;
803803
unsigned NumSignBits =
@@ -1737,7 +1737,7 @@ bool TargetLowering::SimplifyDemandedBits(
17371737
EVT ShiftVT = Op1.getValueType();
17381738

17391739
if (std::optional<uint64_t> KnownSA =
1740-
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
1740+
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
17411741
unsigned ShAmt = *KnownSA;
17421742
if (ShAmt == 0)
17431743
return TLO.CombineTo(Op, Op0);
@@ -1749,7 +1749,7 @@ bool TargetLowering::SimplifyDemandedBits(
17491749
if (Op0.getOpcode() == ISD::SRL) {
17501750
if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
17511751
if (std::optional<uint64_t> InnerSA =
1752-
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
1752+
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
17531753
unsigned C1 = *InnerSA;
17541754
unsigned Opc = ISD::SHL;
17551755
int Diff = ShAmt - C1;
@@ -1789,7 +1789,7 @@ bool TargetLowering::SimplifyDemandedBits(
17891789
if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
17901790
InnerOp.hasOneUse()) {
17911791
if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
1792-
InnerOp, DemandedElts, Depth + 1)) {
1792+
InnerOp, DemandedElts, Depth + 2)) {
17931793
unsigned InnerShAmt = *SA2;
17941794
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
17951795
DemandedBits.getActiveBits() <=
@@ -1918,7 +1918,7 @@ bool TargetLowering::SimplifyDemandedBits(
19181918
// If we are only demanding sign bits then we can use the shift source
19191919
// directly.
19201920
if (std::optional<uint64_t> MaxSA =
1921-
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth)) {
1921+
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
19221922
unsigned ShAmt = *MaxSA;
19231923
unsigned NumSignBits =
19241924
TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
@@ -1934,7 +1934,7 @@ bool TargetLowering::SimplifyDemandedBits(
19341934
EVT ShiftVT = Op1.getValueType();
19351935

19361936
if (std::optional<uint64_t> KnownSA =
1937-
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
1937+
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
19381938
unsigned ShAmt = *KnownSA;
19391939
if (ShAmt == 0)
19401940
return TLO.CombineTo(Op, Op0);
@@ -1946,7 +1946,7 @@ bool TargetLowering::SimplifyDemandedBits(
19461946
if (Op0.getOpcode() == ISD::SHL) {
19471947
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
19481948
if (std::optional<uint64_t> InnerSA =
1949-
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
1949+
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
19501950
unsigned C1 = *InnerSA;
19511951
unsigned Opc = ISD::SRL;
19521952
int Diff = ShAmt - C1;
@@ -2041,7 +2041,7 @@ bool TargetLowering::SimplifyDemandedBits(
20412041
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
20422042

20432043
if (std::optional<uint64_t> KnownSA =
2044-
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth)) {
2044+
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
20452045
unsigned ShAmt = *KnownSA;
20462046
if (ShAmt == 0)
20472047
return TLO.CombineTo(Op, Op0);
@@ -2050,7 +2050,7 @@ bool TargetLowering::SimplifyDemandedBits(
20502050
// supports sext_inreg.
20512051
if (Op0.getOpcode() == ISD::SHL) {
20522052
if (std::optional<uint64_t> InnerSA =
2053-
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 1)) {
2053+
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
20542054
unsigned LowBits = BitWidth - ShAmt;
20552055
EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
20562056
if (VT.isVector())
@@ -2596,7 +2596,7 @@ bool TargetLowering::SimplifyDemandedBits(
25962596

25972597
if (Src.getNode()->hasOneUse()) {
25982598
std::optional<uint64_t> ShAmtC =
2599-
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 1);
2599+
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
26002600
if (!ShAmtC || *ShAmtC >= BitWidth)
26012601
break;
26022602
uint64_t ShVal = *ShAmtC;

0 commit comments

Comments
 (0)