Skip to content

Commit db6de1a

Browse files
authored
[DAGCombiner][VP] Add DAGCombine for VP_MUL (#80105)
Use visitMUL to combine VP_MUL, share most logic of MUL with VP_MUL. Migrate from https://reviews.llvm.org/D121187
1 parent aa98c75 commit db6de1a

File tree

3 files changed

+602
-44
lines changed

3 files changed

+602
-44
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ namespace {
438438
SDValue visitSUBE(SDNode *N);
439439
SDValue visitUSUBO_CARRY(SDNode *N);
440440
SDValue visitSSUBO_CARRY(SDNode *N);
441-
SDValue visitMUL(SDNode *N);
441+
template <class MatchContextClass> SDValue visitMUL(SDNode *N);
442442
SDValue visitMULFIX(SDNode *N);
443443
SDValue useDivRem(SDNode *N);
444444
SDValue visitSDIV(SDNode *N);
@@ -1855,7 +1855,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
18551855
case ISD::SMULFIXSAT:
18561856
case ISD::UMULFIX:
18571857
case ISD::UMULFIXSAT: return visitMULFIX(N);
1858-
case ISD::MUL: return visitMUL(N);
1858+
case ISD::MUL: return visitMUL<EmptyMatchContext>(N);
18591859
case ISD::SDIV: return visitSDIV(N);
18601860
case ISD::UDIV: return visitUDIV(N);
18611861
case ISD::SREM:
@@ -4331,11 +4331,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) {
43314331
return SDValue();
43324332
}
43334333

4334-
SDValue DAGCombiner::visitMUL(SDNode *N) {
4334+
template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
43354335
SDValue N0 = N->getOperand(0);
43364336
SDValue N1 = N->getOperand(1);
43374337
EVT VT = N0.getValueType();
43384338
SDLoc DL(N);
4339+
bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4340+
MatchContextClass Matcher(DAG, TLI, N);
43394341

43404342
// fold (mul x, undef) -> 0
43414343
if (N0.isUndef() || N1.isUndef())
@@ -4348,16 +4350,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
43484350
// canonicalize constant to RHS (vector doesn't have to splat)
43494351
if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
43504352
!DAG.isConstantIntBuildVectorOrConstantInt(N1))
4351-
return DAG.getNode(ISD::MUL, DL, VT, N1, N0);
4353+
return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);
43524354

43534355
bool N1IsConst = false;
43544356
bool N1IsOpaqueConst = false;
43554357
APInt ConstValue1;
43564358

43574359
// fold vector ops
43584360
if (VT.isVector()) {
4359-
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4360-
return FoldedVOp;
4361+
// TODO: Change this to use SimplifyVBinOp when it supports VP op.
4362+
if (!UseVP)
4363+
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4364+
return FoldedVOp;
43614365

43624366
N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
43634367
assert((!N1IsConst ||
@@ -4379,20 +4383,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
43794383
if (N1IsConst && ConstValue1.isOne())
43804384
return N0;
43814385

4382-
if (SDValue NewSel = foldBinOpIntoSelect(N))
4383-
return NewSel;
4386+
if (!UseVP)
4387+
if (SDValue NewSel = foldBinOpIntoSelect(N))
4388+
return NewSel;
43844389

43854390
// fold (mul x, -1) -> 0-x
43864391
if (N1IsConst && ConstValue1.isAllOnes())
4387-
return DAG.getNegative(N0, DL, VT);
4392+
return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
43884393

43894394
// fold (mul x, (1 << c)) -> x << c
43904395
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
43914396
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
43924397
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
43934398
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
43944399
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4395-
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
4400+
return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
43964401
}
43974402
}
43984403

@@ -4403,24 +4408,26 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44034408

44044409
// FIXME: If the input is something that is easily negated (e.g. a
44054410
// single-use add), we should put the negate there.
4406-
return DAG.getNode(ISD::SUB, DL, VT,
4407-
DAG.getConstant(0, DL, VT),
4408-
DAG.getNode(ISD::SHL, DL, VT, N0,
4409-
DAG.getConstant(Log2Val, DL, ShiftVT)));
4411+
return Matcher.getNode(
4412+
ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4413+
Matcher.getNode(ISD::SHL, DL, VT, N0,
4414+
DAG.getConstant(Log2Val, DL, ShiftVT)));
44104415
}
44114416

44124417
// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
44134418
// hi result is in use in case we hit this mid-legalization.
4414-
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4415-
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4416-
SDVTList LoHiVT = DAG.getVTList(VT, VT);
4417-
// TODO: Can we match commutable operands with getNodeIfExists?
4418-
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4419-
if (LoHi->hasAnyUseOfValue(1))
4420-
return SDValue(LoHi, 0);
4421-
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4422-
if (LoHi->hasAnyUseOfValue(1))
4423-
return SDValue(LoHi, 0);
4419+
if (!UseVP) {
4420+
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4421+
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4422+
SDVTList LoHiVT = DAG.getVTList(VT, VT);
4423+
// TODO: Can we match commutable operands with getNodeIfExists?
4424+
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4425+
if (LoHi->hasAnyUseOfValue(1))
4426+
return SDValue(LoHi, 0);
4427+
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4428+
if (LoHi->hasAnyUseOfValue(1))
4429+
return SDValue(LoHi, 0);
4430+
}
44244431
}
44254432
}
44264433

@@ -4439,7 +4446,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44394446
// x * 0xf800 --> (x << 16) - (x << 11)
44404447
// x * -0x8800 --> -((x << 15) + (x << 11))
44414448
// x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4442-
if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4449+
if (!UseVP && N1IsConst &&
4450+
TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
44434451
// TODO: We could handle more general decomposition of any constant by
44444452
// having the target set a limit on number of ops and making a
44454453
// callback to determine that sequence (similar to sqrt expansion).
@@ -4473,7 +4481,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44734481
}
44744482

44754483
// (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4476-
if (N0.getOpcode() == ISD::SHL) {
4484+
if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) {
44774485
SDValue N01 = N0.getOperand(1);
44784486
if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
44794487
return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
@@ -4485,42 +4493,41 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44854493
SDValue Sh, Y;
44864494

44874495
// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4488-
if (N0.getOpcode() == ISD::SHL &&
4489-
isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
4496+
if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4497+
isConstantOrConstantVector(N0.getOperand(1))) {
44904498
Sh = N0; Y = N1;
4491-
} else if (N1.getOpcode() == ISD::SHL &&
4492-
isConstantOrConstantVector(N1.getOperand(1)) &&
4493-
N1->hasOneUse()) {
4499+
} else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4500+
isConstantOrConstantVector(N1.getOperand(1))) {
44944501
Sh = N1; Y = N0;
44954502
}
44964503

44974504
if (Sh.getNode()) {
4498-
SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4499-
return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4505+
SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4506+
return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
45004507
}
45014508
}
45024509

45034510
// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4504-
if (N0.getOpcode() == ISD::ADD &&
4511+
if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) &&
45054512
DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
45064513
DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
45074514
isMulAddWithConstProfitable(N, N0, N1))
4508-
return DAG.getNode(
4515+
return Matcher.getNode(
45094516
ISD::ADD, DL, VT,
4510-
DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4511-
DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4517+
Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4518+
Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
45124519

45134520
// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
45144521
ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4515-
if (N0.getOpcode() == ISD::VSCALE && NC1) {
4522+
if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
45164523
const APInt &C0 = N0.getConstantOperandAPInt(0);
45174524
const APInt &C1 = NC1->getAPIntValue();
45184525
return DAG.getVScale(DL, VT, C0 * C1);
45194526
}
45204527

45214528
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
45224529
APInt MulVal;
4523-
if (N0.getOpcode() == ISD::STEP_VECTOR &&
4530+
if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
45244531
ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
45254532
const APInt &C0 = N0.getConstantOperandAPInt(0);
45264533
APInt NewStep = C0 * MulVal;
@@ -4558,13 +4565,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
45584565
}
45594566

45604567
// reassociate mul
4561-
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4562-
return RMUL;
4568+
// TODO: Change reassociateOps to support vp ops.
4569+
if (!UseVP)
4570+
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4571+
return RMUL;
45634572

45644573
// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4565-
if (SDValue SD =
4566-
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4567-
return SD;
4574+
// TODO: Change reassociateReduction to support vp ops.
4575+
if (!UseVP)
4576+
if (SDValue SD =
4577+
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4578+
return SD;
45684579

45694580
// Simplify the operands using demanded-bits information.
45704581
if (SimplifyDemandedBits(SDValue(N, 0)))
@@ -26693,6 +26704,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
2669326704
return visitFMA<VPMatchContext>(N);
2669426705
case ISD::VP_SELECT:
2669526706
return visitVP_SELECT(N);
26707+
case ISD::VP_MUL:
26708+
return visitMUL<VPMatchContext>(N);
26709+
default:
26710+
break;
2669626711
}
2669726712
return SDValue();
2669826713
}
@@ -27850,6 +27865,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
2785027865
if (!VT.isVector())
2785127866
return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
2785227867
// We need to create a build vector
27868+
if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27869+
return DAG.getSplat(VT, DL,
27870+
DAG.getConstant(Pow2Constants.back().logBase2(), DL,
27871+
VT.getScalarType()));
2785327872
SmallVector<SDValue> Log2Ops;
2785427873
for (const APInt &Pow2 : Pow2Constants)
2785527874
Log2Ops.emplace_back(

0 commit comments

Comments
 (0)