@@ -13857,6 +13857,52 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
13857
13857
return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
13858
13858
}
13859
13859
13860
+ // Fold
13861
+ // y = lshr i64 x, 32
13862
+ // res = add (mul i64 y, Constant), x where "Constant" is a 32 bit
13863
+ // negative value
13864
+ // To
13865
+ // res = mad_u64_u32 y.lo ,Constant.lo, x.lo
13866
+ static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
13867
+ SDValue MulLHS, SDValue MulRHS,
13868
+ SDValue AddRHS) {
13869
+
13870
+ if (MulLHS.getValueType() != MVT::i64)
13871
+ return SDValue();
13872
+
13873
+ ConstantSDNode *ConstOp;
13874
+ SDValue ShiftOp;
13875
+ if (MulLHS.getOpcode() == ISD::SRL && MulRHS.getOpcode() == ISD::Constant) {
13876
+ ConstOp = cast<ConstantSDNode>(MulRHS.getNode());
13877
+ ShiftOp = MulLHS;
13878
+ } else if (MulRHS.getOpcode() == ISD::SRL &&
13879
+ MulLHS.getOpcode() == ISD::Constant) {
13880
+ ConstOp = cast<ConstantSDNode>(MulLHS.getNode());
13881
+ ShiftOp = MulRHS;
13882
+ } else
13883
+ return SDValue();
13884
+
13885
+ if (ShiftOp.getOperand(1).getOpcode() != ISD::Constant ||
13886
+ AddRHS != ShiftOp.getOperand(0))
13887
+ return SDValue();
13888
+
13889
+ if (cast<ConstantSDNode>(ShiftOp->getOperand(1))->getAsZExtVal() != 32)
13890
+ return SDValue();
13891
+
13892
+ APInt ConstVal = ConstOp->getAPIntValue();
13893
+ if (!ConstVal.isNegative() || !ConstVal.isSignedIntN(33))
13894
+ return SDValue();
13895
+
13896
+ SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
13897
+ SDValue ConstMul = DAG.getConstant(
13898
+ ConstVal.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
13899
+ AddRHS = DAG.getNode(ISD::AND, SL, MVT::i64, AddRHS,
13900
+ DAG.getConstant(0x00000000FFFFFFFF, SL, MVT::i64));
13901
+ return getMad64_32(DAG, SL, MVT::i64,
13902
+ DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, MulLHS), ConstMul,
13903
+ AddRHS, false);
13904
+ }
13905
+
13860
13906
// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z) plus high
13861
13907
// multiplies, if any.
13862
13908
//
@@ -13915,6 +13961,9 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
13915
13961
SDValue MulRHS = LHS.getOperand(1);
13916
13962
SDValue AddRHS = RHS;
13917
13963
13964
+ if (SDValue FoldedMAD = tryFoldMADwithSRL(DAG, SL, MulLHS, MulRHS, AddRHS))
13965
+ return FoldedMAD;
13966
+
13918
13967
// Always check whether operands are small unsigned values, since that
13919
13968
// knowledge is useful in more cases. Check for small signed values only if
13920
13969
// doing so can unlock a shorter code sequence.
0 commit comments