@@ -13870,24 +13870,22 @@ static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
13870
13870
if (MulLHS.getValueType() != MVT::i64 || MulLHS.getOpcode() != ISD::SRL)
13871
13871
return SDValue();
13872
13872
13873
- if (MulLHS.getOperand(1).getOpcode() != ISD::Constant ||
13874
- MulLHS.getOperand(0) != AddRHS)
13873
+ ConstantSDNode *ShiftVal = dyn_cast<ConstantSDNode> (MulLHS.getOperand(1));
13874
+ if (!ShiftVal || MulLHS.getOperand(0) != AddRHS)
13875
13875
return SDValue();
13876
13876
13877
- if (cast<ConstantSDNode>(MulLHS->getOperand(1)) ->getAsZExtVal() != 32)
13877
+ if (ShiftVal ->getAsZExtVal() != 32)
13878
13878
return SDValue();
13879
13879
13880
- APInt Const = cast <ConstantSDNode>(MulRHS.getNode())->getAPIntValue();
13880
+ APInt Const = dyn_cast <ConstantSDNode>(MulRHS.getNode())->getAPIntValue();
13881
13881
if (!Const.isNegative() || !Const.isSignedIntN(33))
13882
13882
return SDValue();
13883
13883
13884
13884
SDValue ConstMul =
13885
13885
DAG.getConstant(Const.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
13886
- AddRHS = DAG.getNode(ISD::AND, SL, MVT::i64, AddRHS,
13887
- DAG.getConstant(0x00000000FFFFFFFF, SL, MVT::i64));
13888
13886
return getMad64_32(DAG, SL, MVT::i64,
13889
13887
DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, MulLHS), ConstMul,
13890
- AddRHS, false);
13888
+ DAG.getZeroExtendInReg( AddRHS, SL, MVT::i32) , false);
13891
13889
}
13892
13890
13893
13891
// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z) plus high
@@ -13948,8 +13946,7 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
13948
13946
SDValue MulRHS = LHS.getOperand(1);
13949
13947
SDValue AddRHS = RHS;
13950
13948
13951
- if (MulLHS.getOpcode() == ISD::Constant ||
13952
- MulRHS.getOpcode() == ISD::Constant) {
13949
+ if (isa<ConstantSDNode>(MulLHS) || isa<ConstantSDNode>(MulRHS)) {
13953
13950
if (MulRHS.getOpcode() == ISD::SRL)
13954
13951
std::swap(MulLHS, MulRHS);
13955
13952
0 commit comments