@@ -438,7 +438,7 @@ namespace {
438
438
SDValue visitSUBE(SDNode *N);
439
439
SDValue visitUSUBO_CARRY(SDNode *N);
440
440
SDValue visitSSUBO_CARRY(SDNode *N);
441
- SDValue visitMUL(SDNode *N);
441
+ template <class MatchContextClass> SDValue visitMUL(SDNode *N);
442
442
SDValue visitMULFIX(SDNode *N);
443
443
SDValue useDivRem(SDNode *N);
444
444
SDValue visitSDIV(SDNode *N);
@@ -1855,7 +1855,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
1855
1855
case ISD::SMULFIXSAT:
1856
1856
case ISD::UMULFIX:
1857
1857
case ISD::UMULFIXSAT: return visitMULFIX(N);
1858
- case ISD::MUL: return visitMUL(N);
1858
+ case ISD::MUL: return visitMUL<EmptyMatchContext> (N);
1859
1859
case ISD::SDIV: return visitSDIV(N);
1860
1860
case ISD::UDIV: return visitUDIV(N);
1861
1861
case ISD::SREM:
@@ -4331,11 +4331,13 @@ SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4331
4331
return SDValue();
4332
4332
}
4333
4333
4334
- SDValue DAGCombiner::visitMUL(SDNode *N) {
4334
+ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4335
4335
SDValue N0 = N->getOperand(0);
4336
4336
SDValue N1 = N->getOperand(1);
4337
4337
EVT VT = N0.getValueType();
4338
4338
SDLoc DL(N);
4339
+ bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4340
+ MatchContextClass Matcher(DAG, TLI, N);
4339
4341
4340
4342
// fold (mul x, undef) -> 0
4341
4343
if (N0.isUndef() || N1.isUndef())
@@ -4348,16 +4350,18 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4348
4350
// canonicalize constant to RHS (vector doesn't have to splat)
4349
4351
if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4350
4352
!DAG.isConstantIntBuildVectorOrConstantInt(N1))
4351
- return DAG .getNode(ISD::MUL, DL, VT, N1, N0);
4353
+ return Matcher .getNode(ISD::MUL, DL, VT, N1, N0);
4352
4354
4353
4355
bool N1IsConst = false;
4354
4356
bool N1IsOpaqueConst = false;
4355
4357
APInt ConstValue1;
4356
4358
4357
4359
// fold vector ops
4358
4360
if (VT.isVector()) {
4359
- if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4360
- return FoldedVOp;
4361
+ // TODO: Use SimplifyVBinOp here once it supports VP ops.
4362
+ if (!UseVP)
4363
+ if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4364
+ return FoldedVOp;
4361
4365
4362
4366
N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4363
4367
assert((!N1IsConst ||
@@ -4379,20 +4383,21 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4379
4383
if (N1IsConst && ConstValue1.isOne())
4380
4384
return N0;
4381
4385
4382
- if (SDValue NewSel = foldBinOpIntoSelect(N))
4383
- return NewSel;
4386
+ if (!UseVP)
4387
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
4388
+ return NewSel;
4384
4389
4385
4390
// fold (mul x, -1) -> 0-x
4386
4391
if (N1IsConst && ConstValue1.isAllOnes())
4387
- return DAG.getNegative(N0 , DL, VT);
4392
+ return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0 , DL, VT), N0 );
4388
4393
4389
4394
// fold (mul x, (1 << c)) -> x << c
4390
4395
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4391
4396
(!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4392
4397
if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4393
4398
EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4394
4399
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4395
- return DAG .getNode(ISD::SHL, DL, VT, N0, Trunc);
4400
+ return Matcher .getNode(ISD::SHL, DL, VT, N0, Trunc);
4396
4401
}
4397
4402
}
4398
4403
@@ -4403,24 +4408,26 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4403
4408
4404
4409
// FIXME: If the input is something that is easily negated (e.g. a
4405
4410
// single-use add), we should put the negate there.
4406
- return DAG .getNode(ISD::SUB, DL, VT,
4407
- DAG.getConstant(0, DL, VT),
4408
- DAG .getNode(ISD::SHL, DL, VT, N0,
4409
- DAG.getConstant(Log2Val, DL, ShiftVT)));
4411
+ return Matcher .getNode(
4412
+ ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4413
+ Matcher .getNode(ISD::SHL, DL, VT, N0,
4414
+ DAG.getConstant(Log2Val, DL, ShiftVT)));
4410
4415
}
4411
4416
4412
4417
// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4413
4418
// hi result is in use in case we hit this mid-legalization.
4414
- for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4415
- if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4416
- SDVTList LoHiVT = DAG.getVTList(VT, VT);
4417
- // TODO: Can we match commutable operands with getNodeIfExists?
4418
- if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4419
- if (LoHi->hasAnyUseOfValue(1))
4420
- return SDValue(LoHi, 0);
4421
- if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4422
- if (LoHi->hasAnyUseOfValue(1))
4423
- return SDValue(LoHi, 0);
4419
+ if (!UseVP) {
4420
+ for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4421
+ if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4422
+ SDVTList LoHiVT = DAG.getVTList(VT, VT);
4423
+ // TODO: Can we match commutable operands with getNodeIfExists?
4424
+ if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4425
+ if (LoHi->hasAnyUseOfValue(1))
4426
+ return SDValue(LoHi, 0);
4427
+ if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4428
+ if (LoHi->hasAnyUseOfValue(1))
4429
+ return SDValue(LoHi, 0);
4430
+ }
4424
4431
}
4425
4432
}
4426
4433
@@ -4439,7 +4446,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4439
4446
// x * 0xf800 --> (x << 16) - (x << 11)
4440
4447
// x * -0x8800 --> -((x << 15) + (x << 11))
4441
4448
// x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4442
- if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4449
+ if (!UseVP && N1IsConst &&
4450
+ TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4443
4451
// TODO: We could handle more general decomposition of any constant by
4444
4452
// having the target set a limit on number of ops and making a
4445
4453
// callback to determine that sequence (similar to sqrt expansion).
@@ -4473,7 +4481,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4473
4481
}
4474
4482
4475
4483
// (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4476
- if (N0.getOpcode() == ISD::SHL) {
4484
+ if (sd_context_match(N0, Matcher, m_Opc( ISD::SHL)) ) {
4477
4485
SDValue N01 = N0.getOperand(1);
4478
4486
if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4479
4487
return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
@@ -4485,42 +4493,41 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4485
4493
SDValue Sh, Y;
4486
4494
4487
4495
// Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4488
- if (N0.getOpcode() == ISD::SHL &&
4489
- isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse() ) {
4496
+ if (sd_context_match(N0, Matcher, m_OneUse(m_Opc( ISD::SHL))) &&
4497
+ isConstantOrConstantVector(N0.getOperand(1))) {
4490
4498
Sh = N0; Y = N1;
4491
- } else if (N1.getOpcode() == ISD::SHL &&
4492
- isConstantOrConstantVector(N1.getOperand(1)) &&
4493
- N1->hasOneUse()) {
4499
+ } else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4500
+ isConstantOrConstantVector(N1.getOperand(1))) {
4494
4501
Sh = N1; Y = N0;
4495
4502
}
4496
4503
4497
4504
if (Sh.getNode()) {
4498
- SDValue Mul = DAG .getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4499
- return DAG .getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4505
+ SDValue Mul = Matcher .getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4506
+ return Matcher .getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4500
4507
}
4501
4508
}
4502
4509
4503
4510
// fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4504
- if (N0.getOpcode() == ISD::ADD &&
4511
+ if (sd_context_match(N0, Matcher, m_Opc( ISD::ADD)) &&
4505
4512
DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4506
4513
DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4507
4514
isMulAddWithConstProfitable(N, N0, N1))
4508
- return DAG .getNode(
4515
+ return Matcher .getNode(
4509
4516
ISD::ADD, DL, VT,
4510
- DAG .getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4511
- DAG .getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4517
+ Matcher .getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4518
+ Matcher .getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4512
4519
4513
4520
// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4514
4521
ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4515
- if (N0.getOpcode() == ISD::VSCALE && NC1) {
4522
+ if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4516
4523
const APInt &C0 = N0.getConstantOperandAPInt(0);
4517
4524
const APInt &C1 = NC1->getAPIntValue();
4518
4525
return DAG.getVScale(DL, VT, C0 * C1);
4519
4526
}
4520
4527
4521
4528
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4522
4529
APInt MulVal;
4523
- if (N0.getOpcode() == ISD::STEP_VECTOR &&
4530
+ if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4524
4531
ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4525
4532
const APInt &C0 = N0.getConstantOperandAPInt(0);
4526
4533
APInt NewStep = C0 * MulVal;
@@ -4558,13 +4565,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4558
4565
}
4559
4566
4560
4567
// reassociate mul
4561
- if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4562
- return RMUL;
4568
+ // TODO: Change reassociateOps to support VP ops.
4569
+ if (!UseVP)
4570
+ if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4571
+ return RMUL;
4563
4572
4564
4573
// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4565
- if (SDValue SD =
4566
- reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4567
- return SD;
4574
+ // TODO: Change reassociateReduction to support VP ops.
4575
+ if (!UseVP)
4576
+ if (SDValue SD =
4577
+ reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4578
+ return SD;
4568
4579
4569
4580
// Simplify the operands using demanded-bits information.
4570
4581
if (SimplifyDemandedBits(SDValue(N, 0)))
@@ -26693,6 +26704,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
26693
26704
return visitFMA<VPMatchContext>(N);
26694
26705
case ISD::VP_SELECT:
26695
26706
return visitVP_SELECT(N);
26707
+ case ISD::VP_MUL:
26708
+ return visitMUL<VPMatchContext>(N);
26709
+ default:
26710
+ break;
26696
26711
}
26697
26712
return SDValue();
26698
26713
}
@@ -27850,6 +27865,10 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27850
27865
if (!VT.isVector())
27851
27866
return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
27852
27867
// We need to create a build vector
27868
+ if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27869
+ return DAG.getSplat(VT, DL,
27870
+ DAG.getConstant(Pow2Constants.back().logBase2(), DL,
27871
+ VT.getScalarType()));
27853
27872
SmallVector<SDValue> Log2Ops;
27854
27873
for (const APInt &Pow2 : Pow2Constants)
27855
27874
Log2Ops.emplace_back(
0 commit comments