Skip to content

Commit 22921e6

Browse files
committed
[DAGCombiner][VP] Add DAGCombine for VP_MUL.
Use visitMUL to combine VP_MUL as well, so that VP_MUL shares most of the existing MUL combine logic. Differential Revision: https://reviews.llvm.org/D121187
1 parent 8a98091 commit 22921e6

File tree

3 files changed

+601
-42
lines changed

3 files changed

+601
-42
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ namespace {
439439
SDValue visitSUBE(SDNode *N);
440440
SDValue visitUSUBO_CARRY(SDNode *N);
441441
SDValue visitSSUBO_CARRY(SDNode *N);
442-
SDValue visitMUL(SDNode *N);
442+
template <class MatchContextClass> SDValue visitMUL(SDNode *N);
443443
SDValue visitMULFIX(SDNode *N);
444444
SDValue useDivRem(SDNode *N);
445445
SDValue visitSDIV(SDNode *N);
@@ -1948,7 +1948,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
19481948
case ISD::SMULFIXSAT:
19491949
case ISD::UMULFIX:
19501950
case ISD::UMULFIXSAT: return visitMULFIX(N);
1951-
case ISD::MUL: return visitMUL(N);
1951+
case ISD::MUL:
1952+
return visitMUL<EmptyMatchContext>(N);
19521953
case ISD::SDIV: return visitSDIV(N);
19531954
case ISD::UDIV: return visitUDIV(N);
19541955
case ISD::SREM:
@@ -4356,11 +4357,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) {
43564357
return SDValue();
43574358
}
43584359

4359-
SDValue DAGCombiner::visitMUL(SDNode *N) {
4360+
template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
43604361
SDValue N0 = N->getOperand(0);
43614362
SDValue N1 = N->getOperand(1);
43624363
EVT VT = N0.getValueType();
43634364
SDLoc DL(N);
4365+
bool IsVP = ISD::isVPOpcode(N->getOpcode());
4366+
MatchContextClass matcher(DAG, TLI, N);
43644367

43654368
// fold (mul x, undef) -> 0
43664369
if (N0.isUndef() || N1.isUndef())
@@ -4373,16 +4376,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
43734376
// canonicalize constant to RHS (vector doesn't have to splat)
43744377
if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
43754378
!DAG.isConstantIntBuildVectorOrConstantInt(N1))
4376-
return DAG.getNode(ISD::MUL, DL, VT, N1, N0);
4379+
return matcher.getNode(ISD::MUL, DL, VT, N1, N0);
43774380

43784381
bool N1IsConst = false;
43794382
bool N1IsOpaqueConst = false;
43804383
APInt ConstValue1;
43814384

43824385
// fold vector ops
43834386
if (VT.isVector()) {
4384-
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4385-
return FoldedVOp;
4387+
// TODO: Change this to use SimplifyVBinOp when it supports VP op.
4388+
if (!IsVP)
4389+
if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4390+
return FoldedVOp;
43864391

43874392
N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
43884393
assert((!N1IsConst ||
@@ -4404,20 +4409,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44044409
if (N1IsConst && ConstValue1.isOne())
44054410
return N0;
44064411

4407-
if (SDValue NewSel = foldBinOpIntoSelect(N))
4408-
return NewSel;
4412+
if (!IsVP)
4413+
if (SDValue NewSel = foldBinOpIntoSelect(N))
4414+
return NewSel;
44094415

44104416
// fold (mul x, -1) -> 0-x
44114417
if (N1IsConst && ConstValue1.isAllOnes())
4412-
return DAG.getNegative(N0, DL, VT);
4418+
return matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
44134419

44144420
// fold (mul x, (1 << c)) -> x << c
44154421
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
44164422
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
44174423
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
44184424
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
44194425
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4420-
return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
4426+
return matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
44214427
}
44224428
}
44234429

@@ -4428,26 +4434,27 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44284434

44294435
// FIXME: If the input is something that is easily negated (e.g. a
44304436
// single-use add), we should put the negate there.
4431-
return DAG.getNode(ISD::SUB, DL, VT,
4432-
DAG.getConstant(0, DL, VT),
4433-
DAG.getNode(ISD::SHL, DL, VT, N0,
4434-
DAG.getConstant(Log2Val, DL, ShiftVT)));
4437+
return matcher.getNode(
4438+
ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4439+
matcher.getNode(ISD::SHL, DL, VT, N0,
4440+
DAG.getConstant(Log2Val, DL, ShiftVT)));
44354441
}
44364442

44374443
// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
44384444
// hi result is in use in case we hit this mid-legalization.
4439-
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4440-
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4441-
SDVTList LoHiVT = DAG.getVTList(VT, VT);
4442-
// TODO: Can we match commutable operands with getNodeIfExists?
4443-
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4444-
if (LoHi->hasAnyUseOfValue(1))
4445-
return SDValue(LoHi, 0);
4446-
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4447-
if (LoHi->hasAnyUseOfValue(1))
4448-
return SDValue(LoHi, 0);
4445+
if (!IsVP)
4446+
for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4447+
if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4448+
SDVTList LoHiVT = DAG.getVTList(VT, VT);
4449+
// TODO: Can we match commutable operands with getNodeIfExists?
4450+
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4451+
if (LoHi->hasAnyUseOfValue(1))
4452+
return SDValue(LoHi, 0);
4453+
if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4454+
if (LoHi->hasAnyUseOfValue(1))
4455+
return SDValue(LoHi, 0);
4456+
}
44494457
}
4450-
}
44514458

44524459
// Try to transform:
44534460
// (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
@@ -4464,7 +4471,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44644471
// x * 0xf800 --> (x << 16) - (x << 11)
44654472
// x * -0x8800 --> -((x << 15) + (x << 11))
44664473
// x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4467-
if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4474+
if (!IsVP && N1IsConst &&
4475+
TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
44684476
// TODO: We could handle more general decomposition of any constant by
44694477
// having the target set a limit on number of ops and making a
44704478
// callback to determine that sequence (similar to sqrt expansion).
@@ -4498,7 +4506,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
44984506
}
44994507

45004508
// (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4501-
if (N0.getOpcode() == ISD::SHL) {
4509+
if (matcher.match(N0, ISD::SHL)) {
45024510
SDValue N01 = N0.getOperand(1);
45034511
if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
45044512
return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
@@ -4510,42 +4518,42 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
45104518
SDValue Sh, Y;
45114519

45124520
// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4513-
if (N0.getOpcode() == ISD::SHL &&
4521+
if (matcher.match(N0, ISD::SHL) &&
45144522
isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
45154523
Sh = N0; Y = N1;
4516-
} else if (N1.getOpcode() == ISD::SHL &&
4524+
} else if (matcher.match(N1, ISD::SHL) &&
45174525
isConstantOrConstantVector(N1.getOperand(1)) &&
45184526
N1->hasOneUse()) {
45194527
Sh = N1; Y = N0;
45204528
}
45214529

45224530
if (Sh.getNode()) {
4523-
SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4524-
return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4531+
SDValue Mul = matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4532+
return matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
45254533
}
45264534
}
45274535

45284536
// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4529-
if (N0.getOpcode() == ISD::ADD &&
4537+
if (matcher.match(N0, ISD::ADD) &&
45304538
DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
45314539
DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
45324540
isMulAddWithConstProfitable(N, N0, N1))
4533-
return DAG.getNode(
4541+
return matcher.getNode(
45344542
ISD::ADD, DL, VT,
4535-
DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4536-
DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4543+
matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4544+
matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
45374545

45384546
// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
45394547
ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4540-
if (N0.getOpcode() == ISD::VSCALE && NC1) {
4548+
if (!IsVP && N0.getOpcode() == ISD::VSCALE && NC1) {
45414549
const APInt &C0 = N0.getConstantOperandAPInt(0);
45424550
const APInt &C1 = NC1->getAPIntValue();
45434551
return DAG.getVScale(DL, VT, C0 * C1);
45444552
}
45454553

45464554
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
45474555
APInt MulVal;
4548-
if (N0.getOpcode() == ISD::STEP_VECTOR &&
4556+
if (!IsVP && N0.getOpcode() == ISD::STEP_VECTOR &&
45494557
ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
45504558
const APInt &C0 = N0.getConstantOperandAPInt(0);
45514559
APInt NewStep = C0 * MulVal;
@@ -4583,13 +4591,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
45834591
}
45844592

45854593
// reassociate mul
4586-
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4587-
return RMUL;
4594+
// TODO: Change reassociateOps to support vp ops.
4595+
if (!IsVP)
4596+
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4597+
return RMUL;
45884598

45894599
// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4590-
if (SDValue SD =
4591-
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4592-
return SD;
4600+
// TODO: Change reassociateReduction to support vp ops.
4601+
if (!IsVP)
4602+
if (SDValue SD =
4603+
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4604+
return SD;
45934605

45944606
// Simplify the operands using demanded-bits information.
45954607
if (SimplifyDemandedBits(SDValue(N, 0)))
@@ -26421,6 +26433,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
2642126433
return visitFMA<VPMatchContext>(N);
2642226434
case ISD::VP_SELECT:
2642326435
return visitVP_SELECT(N);
26436+
case ISD::VP_MUL:
26437+
return visitMUL<VPMatchContext>(N);
26438+
default:
26439+
break;
2642426440
}
2642526441
return SDValue();
2642626442
}
@@ -27578,6 +27594,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
2757827594
if (!VT.isVector())
2757927595
return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
2758027596
// We need to create a build vector
27597+
if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27598+
return DAG.getSplat(VT, DL,
27599+
DAG.getConstant(Pow2Constants.back().logBase2(), DL,
27600+
VT.getScalarType()));
2758127601
SmallVector<SDValue> Log2Ops;
2758227602
for (const APInt &Pow2 : Pow2Constants)
2758327603
Log2Ops.emplace_back(

0 commit comments

Comments
 (0)