-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[DAGCombiner][VP] Add DAGCombine for VP_MUL #80105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -438,7 +438,7 @@ namespace { | |
SDValue visitSUBE(SDNode *N); | ||
SDValue visitUSUBO_CARRY(SDNode *N); | ||
SDValue visitSSUBO_CARRY(SDNode *N); | ||
SDValue visitMUL(SDNode *N); | ||
template <class MatchContextClass> SDValue visitMUL(SDNode *N); | ||
SDValue visitMULFIX(SDNode *N); | ||
SDValue useDivRem(SDNode *N); | ||
SDValue visitSDIV(SDNode *N); | ||
|
@@ -1855,7 +1855,7 @@ SDValue DAGCombiner::visit(SDNode *N) { | |
case ISD::SMULFIXSAT: | ||
case ISD::UMULFIX: | ||
case ISD::UMULFIXSAT: return visitMULFIX(N); | ||
case ISD::MUL: return visitMUL(N); | ||
case ISD::MUL: return visitMUL<EmptyMatchContext>(N); | ||
case ISD::SDIV: return visitSDIV(N); | ||
case ISD::UDIV: return visitUDIV(N); | ||
case ISD::SREM: | ||
|
@@ -4331,11 +4331,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) { | |
return SDValue(); | ||
} | ||
|
||
SDValue DAGCombiner::visitMUL(SDNode *N) { | ||
template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) { | ||
SDValue N0 = N->getOperand(0); | ||
SDValue N1 = N->getOperand(1); | ||
EVT VT = N0.getValueType(); | ||
SDLoc DL(N); | ||
bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>; | ||
MatchContextClass Matcher(DAG, TLI, N); | ||
|
||
// fold (mul x, undef) -> 0 | ||
if (N0.isUndef() || N1.isUndef()) | ||
|
@@ -4348,16 +4350,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
// canonicalize constant to RHS (vector doesn't have to splat) | ||
if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && | ||
!DAG.isConstantIntBuildVectorOrConstantInt(N1)) | ||
return DAG.getNode(ISD::MUL, DL, VT, N1, N0); | ||
return Matcher.getNode(ISD::MUL, DL, VT, N1, N0); | ||
|
||
bool N1IsConst = false; | ||
bool N1IsOpaqueConst = false; | ||
APInt ConstValue1; | ||
|
||
// fold vector ops | ||
if (VT.isVector()) { | ||
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) | ||
return FoldedVOp; | ||
// TODO: Change this to use SimplifyVBinOp when it supports VP op. | ||
if (!UseVP) | ||
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL)) | ||
return FoldedVOp; | ||
|
||
sunshaoce marked this conversation as resolved.
Show resolved
Hide resolved
|
||
N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1); | ||
sunshaoce marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert((!N1IsConst || | ||
|
@@ -4379,20 +4383,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
if (N1IsConst && ConstValue1.isOne()) | ||
return N0; | ||
|
||
if (SDValue NewSel = foldBinOpIntoSelect(N)) | ||
return NewSel; | ||
if (!UseVP) | ||
if (SDValue NewSel = foldBinOpIntoSelect(N)) | ||
return NewSel; | ||
|
||
// fold (mul x, -1) -> 0-x | ||
if (N1IsConst && ConstValue1.isAllOnes()) | ||
return DAG.getNegative(N0, DL, VT); | ||
return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); | ||
|
||
// fold (mul x, (1 << c)) -> x << c | ||
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && | ||
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) { | ||
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) { | ||
EVT ShiftVT = getShiftAmountTy(N0.getValueType()); | ||
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); | ||
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); | ||
return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc); | ||
} | ||
} | ||
|
||
|
@@ -4403,24 +4408,26 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
|
||
// FIXME: If the input is something that is easily negated (e.g. a | ||
// single-use add), we should put the negate there. | ||
return DAG.getNode(ISD::SUB, DL, VT, | ||
DAG.getConstant(0, DL, VT), | ||
DAG.getNode(ISD::SHL, DL, VT, N0, | ||
DAG.getConstant(Log2Val, DL, ShiftVT))); | ||
return Matcher.getNode( | ||
ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about the correctness, but this patch has to be rebase first. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is equivalent to no change because of |
||
Matcher.getNode(ISD::SHL, DL, VT, N0, | ||
DAG.getConstant(Log2Val, DL, ShiftVT))); | ||
} | ||
|
||
// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Losing helpers like this is going to get annoying very quickly :( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to add a parameter There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested it and it seems that we should add the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems lots of amounts of helper functions that like |
||
// hi result is in use in case we hit this mid-legalization. | ||
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) { | ||
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) { | ||
SDVTList LoHiVT = DAG.getVTList(VT, VT); | ||
// TODO: Can we match commutable operands with getNodeIfExists? | ||
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1})) | ||
if (LoHi->hasAnyUseOfValue(1)) | ||
return SDValue(LoHi, 0); | ||
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0})) | ||
if (LoHi->hasAnyUseOfValue(1)) | ||
return SDValue(LoHi, 0); | ||
if (!UseVP) { | ||
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) { | ||
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) { | ||
SDVTList LoHiVT = DAG.getVTList(VT, VT); | ||
// TODO: Can we match commutable operands with getNodeIfExists? | ||
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1})) | ||
if (LoHi->hasAnyUseOfValue(1)) | ||
return SDValue(LoHi, 0); | ||
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0})) | ||
if (LoHi->hasAnyUseOfValue(1)) | ||
return SDValue(LoHi, 0); | ||
} | ||
} | ||
} | ||
|
||
|
@@ -4439,7 +4446,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
// x * 0xf800 --> (x << 16) - (x << 11) | ||
// x * -0x8800 --> -((x << 15) + (x << 11)) | ||
// x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16) | ||
if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) { | ||
if (!UseVP && N1IsConst && | ||
TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) { | ||
// TODO: We could handle more general decomposition of any constant by | ||
// having the target set a limit on number of ops and making a | ||
// callback to determine that sequence (similar to sqrt expansion). | ||
|
@@ -4473,7 +4481,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
} | ||
|
||
// (mul (shl X, c1), c2) -> (mul X, c2 << c1) | ||
if (N0.getOpcode() == ISD::SHL) { | ||
if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) { | ||
SDValue N01 = N0.getOperand(1); | ||
if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01})) | ||
return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3); | ||
|
@@ -4485,42 +4493,41 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
SDValue Sh, Y; | ||
|
||
// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)). | ||
if (N0.getOpcode() == ISD::SHL && | ||
isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) { | ||
if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) && | ||
isConstantOrConstantVector(N0.getOperand(1))) { | ||
Sh = N0; Y = N1; | ||
} else if (N1.getOpcode() == ISD::SHL && | ||
isConstantOrConstantVector(N1.getOperand(1)) && | ||
N1->hasOneUse()) { | ||
} else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) && | ||
isConstantOrConstantVector(N1.getOperand(1))) { | ||
Sh = N1; Y = N0; | ||
} | ||
|
||
if (Sh.getNode()) { | ||
SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y); | ||
return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1)); | ||
SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y); | ||
return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1)); | ||
} | ||
} | ||
|
||
// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2) | ||
if (N0.getOpcode() == ISD::ADD && | ||
if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) && | ||
DAG.isConstantIntBuildVectorOrConstantInt(N1) && | ||
DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) && | ||
isMulAddWithConstProfitable(N, N0, N1)) | ||
return DAG.getNode( | ||
return Matcher.getNode( | ||
ISD::ADD, DL, VT, | ||
DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1), | ||
DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1)); | ||
Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1), | ||
Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1)); | ||
|
||
// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)). | ||
ConstantSDNode *NC1 = isConstOrConstSplat(N1); | ||
if (N0.getOpcode() == ISD::VSCALE && NC1) { | ||
if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) { | ||
const APInt &C0 = N0.getConstantOperandAPInt(0); | ||
const APInt &C1 = NC1->getAPIntValue(); | ||
return DAG.getVScale(DL, VT, C0 * C1); | ||
} | ||
|
||
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)). | ||
APInt MulVal; | ||
if (N0.getOpcode() == ISD::STEP_VECTOR && | ||
if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR && | ||
ISD::isConstantSplatVector(N1.getNode(), MulVal)) { | ||
const APInt &C0 = N0.getConstantOperandAPInt(0); | ||
APInt NewStep = C0 * MulVal; | ||
|
@@ -4558,13 +4565,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { | |
} | ||
|
||
// reassociate mul | ||
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags())) | ||
return RMUL; | ||
// TODO: Change reassociateOps to support vp ops. | ||
if (!UseVP) | ||
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags())) | ||
return RMUL; | ||
|
||
// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y)) | ||
if (SDValue SD = | ||
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1)) | ||
return SD; | ||
// TODO: Change reassociateReduction to support vp ops. | ||
if (!UseVP) | ||
if (SDValue SD = | ||
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1)) | ||
return SD; | ||
|
||
// Simplify the operands using demanded-bits information. | ||
if (SimplifyDemandedBits(SDValue(N, 0))) | ||
|
@@ -26693,6 +26704,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) { | |
return visitFMA<VPMatchContext>(N); | ||
case ISD::VP_SELECT: | ||
return visitVP_SELECT(N); | ||
case ISD::VP_MUL: | ||
return visitMUL<VPMatchContext>(N); | ||
default: | ||
break; | ||
} | ||
return SDValue(); | ||
} | ||
|
@@ -27850,6 +27865,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT, | |
if (!VT.isVector()) | ||
return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT); | ||
// We need to create a build vector | ||
if (Op.getOpcode() == ISD::SPLAT_VECTOR) | ||
return DAG.getSplat(VT, DL, | ||
DAG.getConstant(Pow2Constants.back().logBase2(), DL, | ||
VT.getScalarType())); | ||
SmallVector<SDValue> Log2Ops; | ||
for (const APInt &Pow2 : Pow2Constants) | ||
Log2Ops.emplace_back( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't we previously use
bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to this form.