Skip to content

[DAGCombiner][VP] Add DAGCombine for VP_MUL #80105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 63 additions & 44 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ namespace {
SDValue visitSUBE(SDNode *N);
SDValue visitUSUBO_CARRY(SDNode *N);
SDValue visitSSUBO_CARRY(SDNode *N);
SDValue visitMUL(SDNode *N);
template <class MatchContextClass> SDValue visitMUL(SDNode *N);
SDValue visitMULFIX(SDNode *N);
SDValue useDivRem(SDNode *N);
SDValue visitSDIV(SDNode *N);
Expand Down Expand Up @@ -1855,7 +1855,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::SMULFIXSAT:
case ISD::UMULFIX:
case ISD::UMULFIXSAT: return visitMULFIX(N);
case ISD::MUL: return visitMUL(N);
case ISD::MUL: return visitMUL<EmptyMatchContext>(N);
case ISD::SDIV: return visitSDIV(N);
case ISD::UDIV: return visitUDIV(N);
case ISD::SREM:
Expand Down Expand Up @@ -4331,11 +4331,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitMUL(SDNode *N) {
template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
SDLoc DL(N);
bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
MatchContextClass Matcher(DAG, TLI, N);

// fold (mul x, undef) -> 0
if (N0.isUndef() || N1.isUndef())
Expand All @@ -4348,16 +4350,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
// canonicalize constant to RHS (vector doesn't have to splat)
if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
!DAG.isConstantIntBuildVectorOrConstantInt(N1))
return DAG.getNode(ISD::MUL, DL, VT, N1, N0);
return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);

bool N1IsConst = false;
bool N1IsOpaqueConst = false;
APInt ConstValue1;

// fold vector ops
if (VT.isVector()) {
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
return FoldedVOp;
// TODO: Change this to use SimplifyVBinOp when it supports VP op.
if (!UseVP)
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
return FoldedVOp;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't we previously use bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>; ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to this form.


N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
assert((!N1IsConst ||
Expand All @@ -4379,20 +4383,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
if (N1IsConst && ConstValue1.isOne())
return N0;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;
if (!UseVP)
if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (mul x, -1) -> 0-x
if (N1IsConst && ConstValue1.isAllOnes())
return DAG.getNegative(N0, DL, VT);
return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);

// fold (mul x, (1 << c)) -> x << c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
}
}

Expand All @@ -4403,24 +4408,26 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {

// FIXME: If the input is something that is easily negated (e.g. a
// single-use add), we should put the negate there.
return DAG.getNode(ISD::SUB, DL, VT,
DAG.getConstant(0, DL, VT),
DAG.getNode(ISD::SHL, DL, VT, N0,
DAG.getConstant(Log2Val, DL, ShiftVT)));
return Matcher.getNode(
ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the correctness, but this patch has to be rebase first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is equivalent to no change because of !IsVP here.

Matcher.getNode(ISD::SHL, DL, VT, N0,
DAG.getConstant(Log2Val, DL, ShiftVT)));
}

// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Losing helpers like this is going to get annoying very quickly :(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add a parameter bool IsVP = false to getNegative?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested it and it seems that we should add the getNegative function for EmptyMatchContext and VPMatchContext.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems lots of amounts of helper functions that like getNegative, maybe need to find a way to reuse the same logic to avoid duplicated code.

// hi result is in use in case we hit this mid-legalization.
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
SDVTList LoHiVT = DAG.getVTList(VT, VT);
// TODO: Can we match commutable operands with getNodeIfExists?
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
if (LoHi->hasAnyUseOfValue(1))
return SDValue(LoHi, 0);
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
if (LoHi->hasAnyUseOfValue(1))
return SDValue(LoHi, 0);
if (!UseVP) {
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
SDVTList LoHiVT = DAG.getVTList(VT, VT);
// TODO: Can we match commutable operands with getNodeIfExists?
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
if (LoHi->hasAnyUseOfValue(1))
return SDValue(LoHi, 0);
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
if (LoHi->hasAnyUseOfValue(1))
return SDValue(LoHi, 0);
}
}
}

Expand All @@ -4439,7 +4446,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
// x * 0xf800 --> (x << 16) - (x << 11)
// x * -0x8800 --> -((x << 15) + (x << 11))
// x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
if (!UseVP && N1IsConst &&
TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
// TODO: We could handle more general decomposition of any constant by
// having the target set a limit on number of ops and making a
// callback to determine that sequence (similar to sqrt expansion).
Expand Down Expand Up @@ -4473,7 +4481,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
}

// (mul (shl X, c1), c2) -> (mul X, c2 << c1)
if (N0.getOpcode() == ISD::SHL) {
if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) {
SDValue N01 = N0.getOperand(1);
if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
Expand All @@ -4485,42 +4493,41 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
SDValue Sh, Y;

// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
if (N0.getOpcode() == ISD::SHL &&
isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
isConstantOrConstantVector(N0.getOperand(1))) {
Sh = N0; Y = N1;
} else if (N1.getOpcode() == ISD::SHL &&
isConstantOrConstantVector(N1.getOperand(1)) &&
N1->hasOneUse()) {
} else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
isConstantOrConstantVector(N1.getOperand(1))) {
Sh = N1; Y = N0;
}

if (Sh.getNode()) {
SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
}
}

// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
if (N0.getOpcode() == ISD::ADD &&
if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) &&
DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
isMulAddWithConstProfitable(N, N0, N1))
return DAG.getNode(
return Matcher.getNode(
ISD::ADD, DL, VT,
DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));

// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
ConstantSDNode *NC1 = isConstOrConstSplat(N1);
if (N0.getOpcode() == ISD::VSCALE && NC1) {
if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
const APInt &C0 = N0.getConstantOperandAPInt(0);
const APInt &C1 = NC1->getAPIntValue();
return DAG.getVScale(DL, VT, C0 * C1);
}

// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
APInt MulVal;
if (N0.getOpcode() == ISD::STEP_VECTOR &&
if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
const APInt &C0 = N0.getConstantOperandAPInt(0);
APInt NewStep = C0 * MulVal;
Expand Down Expand Up @@ -4558,13 +4565,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
}

// reassociate mul
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
return RMUL;
// TODO: Change reassociateOps to support vp ops.
if (!UseVP)
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
return RMUL;

// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
if (SDValue SD =
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
return SD;
// TODO: Change reassociateReduction to support vp ops.
if (!UseVP)
if (SDValue SD =
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
return SD;

// Simplify the operands using demanded-bits information.
if (SimplifyDemandedBits(SDValue(N, 0)))
Expand Down Expand Up @@ -26693,6 +26704,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
return visitFMA<VPMatchContext>(N);
case ISD::VP_SELECT:
return visitVP_SELECT(N);
case ISD::VP_MUL:
return visitMUL<VPMatchContext>(N);
default:
break;
}
return SDValue();
}
Expand Down Expand Up @@ -27850,6 +27865,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
if (!VT.isVector())
return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
// We need to create a build vector
if (Op.getOpcode() == ISD::SPLAT_VECTOR)
return DAG.getSplat(VT, DL,
DAG.getConstant(Pow2Constants.back().logBase2(), DL,
VT.getScalarType()));
SmallVector<SDValue> Log2Ops;
for (const APInt &Pow2 : Pow2Constants)
Log2Ops.emplace_back(
Expand Down
Loading
Loading