@@ -439,7 +439,7 @@ namespace {
439
439
SDValue visitSUBE(SDNode *N);
440
440
SDValue visitUSUBO_CARRY(SDNode *N);
441
441
SDValue visitSSUBO_CARRY(SDNode *N);
442
- SDValue visitMUL(SDNode *N);
442
+ template <class MatchContextClass> SDValue visitMUL(SDNode *N);
443
443
SDValue visitMULFIX(SDNode *N);
444
444
SDValue useDivRem(SDNode *N);
445
445
SDValue visitSDIV(SDNode *N);
@@ -1948,7 +1948,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
1948
1948
case ISD::SMULFIXSAT:
1949
1949
case ISD::UMULFIX:
1950
1950
case ISD::UMULFIXSAT: return visitMULFIX(N);
1951
- case ISD::MUL: return visitMUL(N);
1951
+ case ISD::MUL:
1952
+ return visitMUL<EmptyMatchContext>(N);
1952
1953
case ISD::SDIV: return visitSDIV(N);
1953
1954
case ISD::UDIV: return visitUDIV(N);
1954
1955
case ISD::SREM:
@@ -4356,11 +4357,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4356
4357
return SDValue();
4357
4358
}
4358
4359
4359
- SDValue DAGCombiner::visitMUL(SDNode *N) {
4360
+ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4360
4361
SDValue N0 = N->getOperand(0);
4361
4362
SDValue N1 = N->getOperand(1);
4362
4363
EVT VT = N0.getValueType();
4363
4364
SDLoc DL(N);
4365
+ bool IsVP = ISD::isVPOpcode(N->getOpcode());
4366
+ MatchContextClass matcher(DAG, TLI, N);
4364
4367
4365
4368
// fold (mul x, undef) -> 0
4366
4369
if (N0.isUndef() || N1.isUndef())
@@ -4373,16 +4376,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4373
4376
// canonicalize constant to RHS (vector doesn't have to splat)
4374
4377
if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4375
4378
!DAG.isConstantIntBuildVectorOrConstantInt(N1))
4376
- return DAG .getNode(ISD::MUL, DL, VT, N1, N0);
4379
+ return matcher .getNode(ISD::MUL, DL, VT, N1, N0);
4377
4380
4378
4381
bool N1IsConst = false;
4379
4382
bool N1IsOpaqueConst = false;
4380
4383
APInt ConstValue1;
4381
4384
4382
4385
// fold vector ops
4383
4386
if (VT.isVector()) {
4384
- if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4385
- return FoldedVOp;
4387
+ // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4388
+ if (!IsVP)
4389
+ if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4390
+ return FoldedVOp;
4386
4391
4387
4392
N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4388
4393
assert((!N1IsConst ||
@@ -4404,20 +4409,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4404
4409
if (N1IsConst && ConstValue1.isOne())
4405
4410
return N0;
4406
4411
4407
- if (SDValue NewSel = foldBinOpIntoSelect(N))
4408
- return NewSel;
4412
+ if (!IsVP)
4413
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
4414
+ return NewSel;
4409
4415
4410
4416
// fold (mul x, -1) -> 0-x
4411
4417
if (N1IsConst && ConstValue1.isAllOnes())
4412
- return DAG.getNegative(N0 , DL, VT);
4418
+ return matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0 , DL, VT), N0 );
4413
4419
4414
4420
// fold (mul x, (1 << c)) -> x << c
4415
4421
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4416
4422
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4417
4423
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4418
4424
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4419
4425
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4420
- return DAG .getNode(ISD::SHL, DL, VT, N0, Trunc);
4426
+ return matcher .getNode(ISD::SHL, DL, VT, N0, Trunc);
4421
4427
}
4422
4428
}
4423
4429
@@ -4428,26 +4434,27 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4428
4434
4429
4435
// FIXME: If the input is something that is easily negated (e.g. a
4430
4436
// single-use add), we should put the negate there.
4431
- return DAG .getNode(ISD::SUB, DL, VT,
4432
- DAG.getConstant(0, DL, VT),
4433
- DAG .getNode(ISD::SHL, DL, VT, N0,
4434
- DAG.getConstant(Log2Val, DL, ShiftVT)));
4437
+ return matcher .getNode(
4438
+ ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4439
+ matcher .getNode(ISD::SHL, DL, VT, N0,
4440
+ DAG.getConstant(Log2Val, DL, ShiftVT)));
4435
4441
}
4436
4442
4437
4443
// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4438
4444
// hi result is in use in case we hit this mid-legalization.
4439
- for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4440
- if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4441
- SDVTList LoHiVT = DAG.getVTList(VT, VT);
4442
- // TODO: Can we match commutable operands with getNodeIfExists?
4443
- if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4444
- if (LoHi->hasAnyUseOfValue(1))
4445
- return SDValue(LoHi, 0);
4446
- if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4447
- if (LoHi->hasAnyUseOfValue(1))
4448
- return SDValue(LoHi, 0);
4445
+ if (!IsVP)
4446
+ for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4447
+ if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4448
+ SDVTList LoHiVT = DAG.getVTList(VT, VT);
4449
+ // TODO: Can we match commutable operands with getNodeIfExists?
4450
+ if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4451
+ if (LoHi->hasAnyUseOfValue(1))
4452
+ return SDValue(LoHi, 0);
4453
+ if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4454
+ if (LoHi->hasAnyUseOfValue(1))
4455
+ return SDValue(LoHi, 0);
4456
+ }
4449
4457
}
4450
- }
4451
4458
4452
4459
// Try to transform:
4453
4460
// (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
@@ -4464,7 +4471,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4464
4471
// x * 0xf800 --> (x << 16) - (x << 11)
4465
4472
// x * -0x8800 --> -((x << 15) + (x << 11))
4466
4473
// x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4467
- if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4474
+ if (!IsVP && N1IsConst &&
4475
+ TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4468
4476
// TODO: We could handle more general decomposition of any constant by
4469
4477
// having the target set a limit on number of ops and making a
4470
4478
// callback to determine that sequence (similar to sqrt expansion).
@@ -4498,7 +4506,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4498
4506
}
4499
4507
4500
4508
// (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4501
- if (N0.getOpcode() == ISD::SHL) {
4509
+ if (matcher.match(N0, ISD::SHL) ) {
4502
4510
SDValue N01 = N0.getOperand(1);
4503
4511
if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4504
4512
return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
@@ -4510,42 +4518,42 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4510
4518
SDValue Sh, Y;
4511
4519
4512
4520
// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4513
- if (N0.getOpcode() == ISD::SHL &&
4521
+ if (matcher.match(N0, ISD::SHL) &&
4514
4522
isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
4515
4523
Sh = N0; Y = N1;
4516
- } else if (N1.getOpcode() == ISD::SHL &&
4524
+ } else if (matcher.match(N1, ISD::SHL) &&
4517
4525
isConstantOrConstantVector(N1.getOperand(1)) &&
4518
4526
N1->hasOneUse()) {
4519
4527
Sh = N1; Y = N0;
4520
4528
}
4521
4529
4522
4530
if (Sh.getNode()) {
4523
- SDValue Mul = DAG .getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4524
- return DAG .getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4531
+ SDValue Mul = matcher .getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4532
+ return matcher .getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4525
4533
}
4526
4534
}
4527
4535
4528
4536
// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4529
- if (N0.getOpcode() == ISD::ADD &&
4537
+ if (matcher.match(N0, ISD::ADD) &&
4530
4538
DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4531
4539
DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4532
4540
isMulAddWithConstProfitable(N, N0, N1))
4533
- return DAG .getNode(
4541
+ return matcher .getNode(
4534
4542
ISD::ADD, DL, VT,
4535
- DAG .getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4536
- DAG .getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4543
+ matcher .getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4544
+ matcher .getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4537
4545
4538
4546
// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4539
4547
ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4540
- if (N0.getOpcode() == ISD::VSCALE && NC1) {
4548
+ if (!IsVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4541
4549
const APInt &C0 = N0.getConstantOperandAPInt(0);
4542
4550
const APInt &C1 = NC1->getAPIntValue();
4543
4551
return DAG.getVScale(DL, VT, C0 * C1);
4544
4552
}
4545
4553
4546
4554
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4547
4555
APInt MulVal;
4548
- if (N0.getOpcode() == ISD::STEP_VECTOR &&
4556
+ if (!IsVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4549
4557
ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4550
4558
const APInt &C0 = N0.getConstantOperandAPInt(0);
4551
4559
APInt NewStep = C0 * MulVal;
@@ -4583,13 +4591,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4583
4591
}
4584
4592
4585
4593
// reassociate mul
4586
- if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4587
- return RMUL;
4594
+ // TODO: Change reassociateOps to support vp ops.
4595
+ if (!IsVP)
4596
+ if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4597
+ return RMUL;
4588
4598
4589
4599
// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4590
- if (SDValue SD =
4591
- reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4592
- return SD;
4600
+ // TODO: Change reassociateReduction to support vp ops.
4601
+ if (!IsVP)
4602
+ if (SDValue SD =
4603
+ reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4604
+ return SD;
4593
4605
4594
4606
// Simplify the operands using demanded-bits information.
4595
4607
if (SimplifyDemandedBits(SDValue(N, 0)))
@@ -26421,6 +26433,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
26421
26433
return visitFMA<VPMatchContext>(N);
26422
26434
case ISD::VP_SELECT:
26423
26435
return visitVP_SELECT(N);
26436
+ case ISD::VP_MUL:
26437
+ return visitMUL<VPMatchContext>(N);
26438
+ default:
26439
+ break;
26424
26440
}
26425
26441
return SDValue();
26426
26442
}
@@ -27578,6 +27594,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27578
27594
if (!VT.isVector())
27579
27595
return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
27580
27596
// We need to create a build vector
27597
+ if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27598
+ return DAG.getSplat(VT, DL,
27599
+ DAG.getConstant(Pow2Constants.back().logBase2(), DL,
27600
+ VT.getScalarType()));
27581
27601
SmallVector<SDValue> Log2Ops;
27582
27602
for (const APInt &Pow2 : Pow2Constants)
27583
27603
Log2Ops.emplace_back(
0 commit comments