Skip to content

[SelectionDAG] Use getShiftAmountConstant to simplify code. NFC #80561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Feb 3, 2024

Replace calls to getShiftAmountTy+getConstant with getShiftAmountContant.

Replace calls to getShiftAmountTy+getConstant with getShiftAmountContant.
@topperc topperc requested review from arsenm and RKSimon February 3, 2024 20:12
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Feb 3, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 3, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

Changes

Replace calls to getShiftAmountTy+getConstant with getShiftAmountContant.


Full diff: https://github.com/llvm/llvm-project/pull/80561.diff

1 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+44-51)
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 03b2a66989bd4..b15f62bc3aae7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1096,7 +1096,6 @@ bool TargetLowering::SimplifyDemandedBits(
   APInt DemandedBits = OriginalDemandedBits;
   APInt DemandedElts = OriginalDemandedElts;
   SDLoc dl(Op);
-  auto &DL = TLO.DAG.getDataLayout();
 
   // Undef operand.
   if (Op.isUndef())
@@ -2288,9 +2287,8 @@ bool TargetLowering::SimplifyDemandedBits(
       // the right place.
       unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
       if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
-        EVT ShiftAmtTy = getShiftAmountTy(VT, DL);
         unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
-        SDValue ShAmt = TLO.DAG.getConstant(ShiftAmount, dl, ShiftAmtTy);
+        SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
         SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
         return TLO.CombineTo(Op, NewOp);
       }
@@ -2330,8 +2328,8 @@ bool TargetLowering::SimplifyDemandedBits(
       if (!AlreadySignExtended) {
         // Compute the correct shift amount type, which must be getShiftAmountTy
         // for scalar types after legalization.
-        SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ExVTBits, dl,
-                                               getShiftAmountTy(VT, DL));
+        SDValue ShiftAmt =
+            TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
         return TLO.CombineTo(Op,
                              TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
       }
@@ -2574,8 +2572,8 @@ bool TargetLowering::SimplifyDemandedBits(
         if (!(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed.  Add a truncate of the
           // shift input, then shift it.
-          SDValue NewShAmt = TLO.DAG.getConstant(
-              ShVal, dl, getShiftAmountTy(VT, DL, TLO.LegalTypes()));
+          SDValue NewShAmt =
+              TLO.DAG.getShiftAmountConstant(ShVal, VT, dl, TLO.LegalTypes());
           SDValue NewTrunc =
               TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
           return TLO.CombineTo(
@@ -2753,8 +2751,7 @@ bool TargetLowering::SimplifyDemandedBits(
       unsigned CTZ = DemandedBits.countr_zero();
       ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
       if (C && C->getAPIntValue().countr_zero() == CTZ) {
-        EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
-        SDValue AmtC = TLO.DAG.getConstant(CTZ, dl, ShiftAmtTy);
+        SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
         SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
         return TLO.CombineTo(Op, Shl);
       }
@@ -2852,9 +2849,9 @@ bool TargetLowering::SimplifyDemandedBits(
       return 0;
     };
 
-    auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y, unsigned ShlAmt) {
-      EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
-      SDValue ShlAmtC = TLO.DAG.getConstant(ShlAmt, dl, ShiftAmtTy);
+    auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
+                       unsigned ShlAmt) {
+      SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
       SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
       SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
       return TLO.CombineTo(Op, Res);
@@ -4204,9 +4201,8 @@ SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
     return SDValue();
 
   // (X - Y) == Y --> X == Y << 1
-  EVT ShiftVT = getShiftAmountTy(OpVT, DAG.getDataLayout(),
-                                 !DCI.isBeforeLegalize());
-  SDValue One = DAG.getConstant(1, DL, ShiftVT);
+  SDValue One =
+      DAG.getShiftAmountConstant(1, OpVT, DL, !DCI.isBeforeLegalize());
   SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
   if (!DCI.isCalledByLegalizer())
     DCI.AddToWorklist(YShl1.getNode());
@@ -5038,16 +5034,16 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         (VT == ShValTy || (isTypeLegal(VT) && VT.bitsLE(ShValTy))) &&
         N0.getOpcode() == ISD::AND) {
       if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
-        EVT ShiftTy =
-            getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
         if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0  -->  (X & 8) >> 3
           // Perform the xform if the AND RHS is a single bit.
           unsigned ShCt = AndRHS->getAPIntValue().logBase2();
           if (AndRHS->getAPIntValue().isPowerOf2() &&
               !TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
-            return DAG.getNode(ISD::TRUNCATE, dl, VT,
-                               DAG.getNode(ISD::SRL, dl, ShValTy, N0,
-                                           DAG.getConstant(ShCt, dl, ShiftTy)));
+            return DAG.getNode(
+                ISD::TRUNCATE, dl, VT,
+                DAG.getNode(ISD::SRL, dl, ShValTy, N0,
+                            DAG.getShiftAmountConstant(
+                                ShCt, ShValTy, dl, !DCI.isBeforeLegalize())));
           }
         } else if (Cond == ISD::SETEQ && C1 == AndRHS->getAPIntValue()) {
           // (X & 8) == 8  -->  (X & 8) >> 3
@@ -5055,9 +5051,11 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
           unsigned ShCt = C1.logBase2();
           if (C1.isPowerOf2() &&
               !TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
-            return DAG.getNode(ISD::TRUNCATE, dl, VT,
-                               DAG.getNode(ISD::SRL, dl, ShValTy, N0,
-                                           DAG.getConstant(ShCt, dl, ShiftTy)));
+            return DAG.getNode(
+                ISD::TRUNCATE, dl, VT,
+                DAG.getNode(ISD::SRL, dl, ShValTy, N0,
+                            DAG.getShiftAmountConstant(
+                                ShCt, ShValTy, dl, !DCI.isBeforeLegalize())));
           }
         }
       }
@@ -5065,7 +5063,6 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
 
     if (C1.getSignificantBits() <= 64 &&
         !isLegalICmpImmediate(C1.getSExtValue())) {
-      EVT ShiftTy = getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
       // (X & -256) == 256 -> (X >> 8) == 1
       if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
           N0.getOpcode() == ISD::AND && N0.hasOneUse()) {
@@ -5074,9 +5071,10 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
           if (AndRHSC.isNegatedPowerOf2() && (AndRHSC & C1) == C1) {
             unsigned ShiftBits = AndRHSC.countr_zero();
             if (!TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
-              SDValue Shift =
-                DAG.getNode(ISD::SRL, dl, ShValTy, N0.getOperand(0),
-                            DAG.getConstant(ShiftBits, dl, ShiftTy));
+              SDValue Shift = DAG.getNode(
+                  ISD::SRL, dl, ShValTy, N0.getOperand(0),
+                  DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl,
+                                             !DCI.isBeforeLegalize()));
               SDValue CmpRHS = DAG.getConstant(C1.lshr(ShiftBits), dl, ShValTy);
               return DAG.getSetCC(dl, VT, Shift, CmpRHS, Cond);
             }
@@ -5103,8 +5101,10 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         if (ShiftBits && NewC.getSignificantBits() <= 64 &&
             isLegalICmpImmediate(NewC.getSExtValue()) &&
             !TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
-          SDValue Shift = DAG.getNode(ISD::SRL, dl, ShValTy, N0,
-                                      DAG.getConstant(ShiftBits, dl, ShiftTy));
+          SDValue Shift =
+              DAG.getNode(ISD::SRL, dl, ShValTy, N0,
+                          DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl,
+                                                     !DCI.isBeforeLegalize()));
           SDValue CmpRHS = DAG.getConstant(NewC, dl, ShValTy);
           return DAG.getSetCC(dl, VT, Shift, CmpRHS, NewCond);
         }
@@ -8944,7 +8944,6 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
                                   bool IsNegative) const {
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
-  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
   SDValue Op = N->getOperand(0);
 
   // abs(x) -> smax(x,sub(0,x))
@@ -8982,9 +8981,9 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   Op = DAG.getFreeze(Op);
-  SDValue Shift =
-      DAG.getNode(ISD::SRA, dl, VT, Op,
-                  DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, ShVT));
+  SDValue Shift = DAG.getNode(
+      ISD::SRA, dl, VT, Op,
+      DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, dl));
   SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, Op, Shift);
 
   // abs(x) -> Y = sra (X, size(X)-1); sub (xor (X, Y), Y)
@@ -9592,9 +9591,7 @@ TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const {
   }
 
   // aggregate the two parts
-  SDValue ShiftAmount =
-      DAG.getConstant(NumBits, dl, getShiftAmountTy(Hi.getValueType(),
-                                                    DAG.getDataLayout()));
+  SDValue ShiftAmount = DAG.getShiftAmountConstant(NumBits, VT, dl);
   SDValue Result = DAG.getNode(ISD::SHL, dl, VT, Hi, ShiftAmount);
   Result = DAG.getNode(ISD::OR, dl, VT, Result, Lo);
 
@@ -9706,8 +9703,8 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST,
   unsigned IncrementSize = NumBits / 8;
 
   // Divide the stored value in two parts.
-  SDValue ShiftAmount = DAG.getConstant(
-      NumBits, dl, getShiftAmountTy(Val.getValueType(), DAG.getDataLayout()));
+  SDValue ShiftAmount =
+      DAG.getShiftAmountConstant(NumBits, Val.getValueType(), dl);
   SDValue Lo = Val;
   // If Val is a constant, replace the upper bits with 0. The SRL will constant
   // fold and not use the upper bits. A smaller constant may be easier to
@@ -10351,9 +10348,8 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
   // The result will need to be shifted right by the scale since both operands
   // are scaled. The result is given to us in 2 halves, so we only want part of
   // both in the result.
-  EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
   SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
-                               DAG.getConstant(Scale, dl, ShiftTy));
+                               DAG.getShiftAmountConstant(Scale, VT, dl));
   if (!Saturating)
     return Result;
 
@@ -10381,7 +10377,7 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
 
   if (Scale == 0) {
     SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, Lo,
-                               DAG.getConstant(VTSize - 1, dl, ShiftTy));
+                               DAG.getShiftAmountConstant(VTSize - 1, VT, dl));
     SDValue Overflow = DAG.getSetCC(dl, BoolVT, Hi, Sign, ISD::SETNE);
     // Saturated to SatMin if wide product is negative, and SatMax if wide
     // product is positive ...
@@ -10448,13 +10444,12 @@ TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl,
   // RHS down by RHSShift, we can emit a regular division with a final scaling
   // factor of Scale.
 
-  EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
   if (LHSShift)
     LHS = DAG.getNode(ISD::SHL, dl, VT, LHS,
-                      DAG.getConstant(LHSShift, dl, ShiftTy));
+                      DAG.getShiftAmountConstant(LHSShift, VT, dl));
   if (RHSShift)
     RHS = DAG.getNode(Signed ? ISD::SRA : ISD::SRL, dl, VT, RHS,
-                      DAG.getConstant(RHSShift, dl, ShiftTy));
+                      DAG.getShiftAmountConstant(RHSShift, VT, dl));
 
   SDValue Quot;
   if (Signed) {
@@ -10597,8 +10592,7 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
     if (C.isPowerOf2()) {
       // smulo(x, signed_min) is same as umulo(x, signed_min).
       bool UseArithShift = isSigned && !C.isMinSignedValue();
-      EVT ShiftAmtTy = getShiftAmountTy(VT, DAG.getDataLayout());
-      SDValue ShiftAmt = DAG.getConstant(C.logBase2(), dl, ShiftAmtTy);
+      SDValue ShiftAmt = DAG.getShiftAmountConstant(C.logBase2(), VT, dl);
       Result = DAG.getNode(ISD::SHL, dl, VT, LHS, ShiftAmt);
       Overflow = DAG.getSetCC(dl, SetCCVT,
           DAG.getNode(UseArithShift ? ISD::SRA : ISD::SRL,
@@ -10630,8 +10624,8 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
     RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS);
     SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS);
     BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
-    SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl,
-        getShiftAmountTy(WideVT, DAG.getDataLayout()));
+    SDValue ShiftAmt =
+        DAG.getShiftAmountConstant(VT.getScalarSizeInBits(), WideVT, dl);
     TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT,
                           DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt));
   } else {
@@ -10643,9 +10637,8 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
 
   Result = BottomHalf;
   if (isSigned) {
-    SDValue ShiftAmt = DAG.getConstant(
-        VT.getScalarSizeInBits() - 1, dl,
-        getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout()));
+    SDValue ShiftAmt = DAG.getShiftAmountConstant(
+        VT.getScalarSizeInBits() - 1, BottomHalf.getValueType(), dl);
     SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt);
     Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
   } else {

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc topperc merged commit f72da9f into llvm:main Feb 5, 2024
@topperc topperc deleted the pr/shiftamountconstant branch February 5, 2024 00:05
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
…#80561)

Replace calls to getShiftAmountTy+getConstant with
getShiftAmountContant.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants