Skip to content

[SelectionDAG] Use getShiftAmountConstant to simplify code. NFC #80561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 44 additions & 51 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,6 @@ bool TargetLowering::SimplifyDemandedBits(
APInt DemandedBits = OriginalDemandedBits;
APInt DemandedElts = OriginalDemandedElts;
SDLoc dl(Op);
auto &DL = TLO.DAG.getDataLayout();

// Undef operand.
if (Op.isUndef())
Expand Down Expand Up @@ -2288,9 +2287,8 @@ bool TargetLowering::SimplifyDemandedBits(
// the right place.
unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
EVT ShiftAmtTy = getShiftAmountTy(VT, DL);
unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
SDValue ShAmt = TLO.DAG.getConstant(ShiftAmount, dl, ShiftAmtTy);
SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
return TLO.CombineTo(Op, NewOp);
}
Expand Down Expand Up @@ -2330,8 +2328,8 @@ bool TargetLowering::SimplifyDemandedBits(
if (!AlreadySignExtended) {
// Compute the correct shift amount type, which must be getShiftAmountTy
// for scalar types after legalization.
SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ExVTBits, dl,
getShiftAmountTy(VT, DL));
SDValue ShiftAmt =
TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
return TLO.CombineTo(Op,
TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
}
Expand Down Expand Up @@ -2574,8 +2572,8 @@ bool TargetLowering::SimplifyDemandedBits(
if (!(HighBits & DemandedBits)) {
// None of the shifted in bits are needed. Add a truncate of the
// shift input, then shift it.
SDValue NewShAmt = TLO.DAG.getConstant(
ShVal, dl, getShiftAmountTy(VT, DL, TLO.LegalTypes()));
SDValue NewShAmt =
TLO.DAG.getShiftAmountConstant(ShVal, VT, dl, TLO.LegalTypes());
SDValue NewTrunc =
TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
return TLO.CombineTo(
Expand Down Expand Up @@ -2753,8 +2751,7 @@ bool TargetLowering::SimplifyDemandedBits(
unsigned CTZ = DemandedBits.countr_zero();
ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
if (C && C->getAPIntValue().countr_zero() == CTZ) {
EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
SDValue AmtC = TLO.DAG.getConstant(CTZ, dl, ShiftAmtTy);
SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
return TLO.CombineTo(Op, Shl);
}
Expand Down Expand Up @@ -2852,9 +2849,9 @@ bool TargetLowering::SimplifyDemandedBits(
return 0;
};

auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y, unsigned ShlAmt) {
EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
SDValue ShlAmtC = TLO.DAG.getConstant(ShlAmt, dl, ShiftAmtTy);
auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
unsigned ShlAmt) {
SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
return TLO.CombineTo(Op, Res);
Expand Down Expand Up @@ -4204,9 +4201,8 @@ SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
return SDValue();

// (X - Y) == Y --> X == Y << 1
EVT ShiftVT = getShiftAmountTy(OpVT, DAG.getDataLayout(),
!DCI.isBeforeLegalize());
SDValue One = DAG.getConstant(1, DL, ShiftVT);
SDValue One =
DAG.getShiftAmountConstant(1, OpVT, DL, !DCI.isBeforeLegalize());
SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
if (!DCI.isCalledByLegalizer())
DCI.AddToWorklist(YShl1.getNode());
Expand Down Expand Up @@ -5038,34 +5034,35 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
(VT == ShValTy || (isTypeLegal(VT) && VT.bitsLE(ShValTy))) &&
N0.getOpcode() == ISD::AND) {
if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
EVT ShiftTy =
getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0 --> (X & 8) >> 3
// Perform the xform if the AND RHS is a single bit.
unsigned ShCt = AndRHS->getAPIntValue().logBase2();
if (AndRHS->getAPIntValue().isPowerOf2() &&
!TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
return DAG.getNode(ISD::TRUNCATE, dl, VT,
DAG.getNode(ISD::SRL, dl, ShValTy, N0,
DAG.getConstant(ShCt, dl, ShiftTy)));
return DAG.getNode(
ISD::TRUNCATE, dl, VT,
DAG.getNode(ISD::SRL, dl, ShValTy, N0,
DAG.getShiftAmountConstant(
ShCt, ShValTy, dl, !DCI.isBeforeLegalize())));
}
} else if (Cond == ISD::SETEQ && C1 == AndRHS->getAPIntValue()) {
// (X & 8) == 8 --> (X & 8) >> 3
// Perform the xform if C1 is a single bit.
unsigned ShCt = C1.logBase2();
if (C1.isPowerOf2() &&
!TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
return DAG.getNode(ISD::TRUNCATE, dl, VT,
DAG.getNode(ISD::SRL, dl, ShValTy, N0,
DAG.getConstant(ShCt, dl, ShiftTy)));
return DAG.getNode(
ISD::TRUNCATE, dl, VT,
DAG.getNode(ISD::SRL, dl, ShValTy, N0,
DAG.getShiftAmountConstant(
ShCt, ShValTy, dl, !DCI.isBeforeLegalize())));
}
}
}
}

if (C1.getSignificantBits() <= 64 &&
!isLegalICmpImmediate(C1.getSExtValue())) {
EVT ShiftTy = getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
// (X & -256) == 256 -> (X >> 8) == 1
if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
N0.getOpcode() == ISD::AND && N0.hasOneUse()) {
Expand All @@ -5074,9 +5071,10 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
if (AndRHSC.isNegatedPowerOf2() && (AndRHSC & C1) == C1) {
unsigned ShiftBits = AndRHSC.countr_zero();
if (!TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
SDValue Shift =
DAG.getNode(ISD::SRL, dl, ShValTy, N0.getOperand(0),
DAG.getConstant(ShiftBits, dl, ShiftTy));
SDValue Shift = DAG.getNode(
ISD::SRL, dl, ShValTy, N0.getOperand(0),
DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl,
!DCI.isBeforeLegalize()));
SDValue CmpRHS = DAG.getConstant(C1.lshr(ShiftBits), dl, ShValTy);
return DAG.getSetCC(dl, VT, Shift, CmpRHS, Cond);
}
Expand All @@ -5103,8 +5101,10 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
if (ShiftBits && NewC.getSignificantBits() <= 64 &&
isLegalICmpImmediate(NewC.getSExtValue()) &&
!TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
SDValue Shift = DAG.getNode(ISD::SRL, dl, ShValTy, N0,
DAG.getConstant(ShiftBits, dl, ShiftTy));
SDValue Shift =
DAG.getNode(ISD::SRL, dl, ShValTy, N0,
DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl,
!DCI.isBeforeLegalize()));
SDValue CmpRHS = DAG.getConstant(NewC, dl, ShValTy);
return DAG.getSetCC(dl, VT, Shift, CmpRHS, NewCond);
}
Expand Down Expand Up @@ -8944,7 +8944,6 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
bool IsNegative) const {
SDLoc dl(N);
EVT VT = N->getValueType(0);
EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
SDValue Op = N->getOperand(0);

// abs(x) -> smax(x,sub(0,x))
Expand Down Expand Up @@ -8982,9 +8981,9 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
return SDValue();

Op = DAG.getFreeze(Op);
SDValue Shift =
DAG.getNode(ISD::SRA, dl, VT, Op,
DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, ShVT));
SDValue Shift = DAG.getNode(
ISD::SRA, dl, VT, Op,
DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, dl));
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, Op, Shift);

// abs(x) -> Y = sra (X, size(X)-1); sub (xor (X, Y), Y)
Expand Down Expand Up @@ -9592,9 +9591,7 @@ TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const {
}

// aggregate the two parts
SDValue ShiftAmount =
DAG.getConstant(NumBits, dl, getShiftAmountTy(Hi.getValueType(),
DAG.getDataLayout()));
SDValue ShiftAmount = DAG.getShiftAmountConstant(NumBits, VT, dl);
SDValue Result = DAG.getNode(ISD::SHL, dl, VT, Hi, ShiftAmount);
Result = DAG.getNode(ISD::OR, dl, VT, Result, Lo);

Expand Down Expand Up @@ -9706,8 +9703,8 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST,
unsigned IncrementSize = NumBits / 8;

// Divide the stored value in two parts.
SDValue ShiftAmount = DAG.getConstant(
NumBits, dl, getShiftAmountTy(Val.getValueType(), DAG.getDataLayout()));
SDValue ShiftAmount =
DAG.getShiftAmountConstant(NumBits, Val.getValueType(), dl);
SDValue Lo = Val;
// If Val is a constant, replace the upper bits with 0. The SRL will constant
// fold and not use the upper bits. A smaller constant may be easier to
Expand Down Expand Up @@ -10351,9 +10348,8 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
// The result will need to be shifted right by the scale since both operands
// are scaled. The result is given to us in 2 halves, so we only want part of
// both in the result.
EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
DAG.getConstant(Scale, dl, ShiftTy));
DAG.getShiftAmountConstant(Scale, VT, dl));
if (!Saturating)
return Result;

Expand Down Expand Up @@ -10381,7 +10377,7 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {

if (Scale == 0) {
SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, Lo,
DAG.getConstant(VTSize - 1, dl, ShiftTy));
DAG.getShiftAmountConstant(VTSize - 1, VT, dl));
SDValue Overflow = DAG.getSetCC(dl, BoolVT, Hi, Sign, ISD::SETNE);
// Saturated to SatMin if wide product is negative, and SatMax if wide
// product is positive ...
Expand Down Expand Up @@ -10448,13 +10444,12 @@ TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl,
// RHS down by RHSShift, we can emit a regular division with a final scaling
// factor of Scale.

EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
if (LHSShift)
LHS = DAG.getNode(ISD::SHL, dl, VT, LHS,
DAG.getConstant(LHSShift, dl, ShiftTy));
DAG.getShiftAmountConstant(LHSShift, VT, dl));
if (RHSShift)
RHS = DAG.getNode(Signed ? ISD::SRA : ISD::SRL, dl, VT, RHS,
DAG.getConstant(RHSShift, dl, ShiftTy));
DAG.getShiftAmountConstant(RHSShift, VT, dl));

SDValue Quot;
if (Signed) {
Expand Down Expand Up @@ -10597,8 +10592,7 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
if (C.isPowerOf2()) {
// smulo(x, signed_min) is same as umulo(x, signed_min).
bool UseArithShift = isSigned && !C.isMinSignedValue();
EVT ShiftAmtTy = getShiftAmountTy(VT, DAG.getDataLayout());
SDValue ShiftAmt = DAG.getConstant(C.logBase2(), dl, ShiftAmtTy);
SDValue ShiftAmt = DAG.getShiftAmountConstant(C.logBase2(), VT, dl);
Result = DAG.getNode(ISD::SHL, dl, VT, LHS, ShiftAmt);
Overflow = DAG.getSetCC(dl, SetCCVT,
DAG.getNode(UseArithShift ? ISD::SRA : ISD::SRL,
Expand Down Expand Up @@ -10630,8 +10624,8 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS);
SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS);
BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl,
getShiftAmountTy(WideVT, DAG.getDataLayout()));
SDValue ShiftAmt =
DAG.getShiftAmountConstant(VT.getScalarSizeInBits(), WideVT, dl);
TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT,
DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt));
} else {
Expand All @@ -10643,9 +10637,8 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,

Result = BottomHalf;
if (isSigned) {
SDValue ShiftAmt = DAG.getConstant(
VT.getScalarSizeInBits() - 1, dl,
getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout()));
SDValue ShiftAmt = DAG.getShiftAmountConstant(
VT.getScalarSizeInBits() - 1, BottomHalf.getValueType(), dl);
SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt);
Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
} else {
Expand Down