@@ -4455,6 +4455,34 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
4455
4455
return DAG.getSelect (dl, VT, IsOne, N0, Q);
4456
4456
}
4457
4457
4458
+ // / If all values in Values that *don't* match the predicate are same 'splat'
4459
+ // / value, then replace all values with that splat value.
4460
+ // / Else, if AlternativeReplacement was provided, then replace all values that
4461
+ // / do match predicate with AlternativeReplacement value.
4462
+ static void
4463
+ turnVectorIntoSplatVector (MutableArrayRef<SDValue> Values,
4464
+ std::function<bool (SDValue)> Predicate,
4465
+ SDValue AlternativeReplacement = SDValue()) {
4466
+ SDValue Replacement;
4467
+ // Is there a value for which the Predicate does *NOT* match? What is it?
4468
+ auto SplatValue = llvm::find_if_not (Values, Predicate);
4469
+ if (SplatValue != Values.end ()) {
4470
+ // Does Values consist only of SplatValue's and values matching Predicate?
4471
+ if (llvm::all_of (Values, [Predicate, SplatValue](SDValue Value) {
4472
+ return Value == *SplatValue || Predicate (Value);
4473
+ })) // Then we shall replace values matching predicate with SplatValue.
4474
+ Replacement = *SplatValue;
4475
+ }
4476
+ if (!Replacement) {
4477
+ // Oops, we did not find the "baseline" splat value.
4478
+ if (!AlternativeReplacement)
4479
+ return ; // Nothing to do.
4480
+ // Let's replace with provided value then.
4481
+ Replacement = AlternativeReplacement;
4482
+ }
4483
+ std::replace_if (Values.begin (), Values.end (), Predicate, Replacement);
4484
+ }
4485
+
4458
4486
/// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE
4459
4487
/// where the divisor is constant and the comparison target is zero,
4460
4488
/// return a DAG expression that will generate the same comparison result
@@ -4482,74 +4510,143 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
4482
4510
DAGCombinerInfo &DCI, const SDLoc &DL,
4483
4511
SmallVectorImpl<SDNode *> &Created) const {
4484
4512
// fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q)
4485
- // - D must be constant with D = D0 * 2^K where D0 is odd and D0 != 1
4513
+ // - D must be constant, with D = D0 * 2^K where D0 is odd
4486
4514
// - P is the multiplicative inverse of D0 modulo 2^W
4487
4515
// - Q = floor((2^W - 1) / D)
4488
4516
// where W is the width of the common type of N and D.
4489
4517
assert ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4490
4518
" Only applicable for (in)equality comparisons." );
4491
4519
4520
+ SelectionDAG &DAG = DCI.DAG ;
4521
+
4492
4522
EVT VT = REMNode.getValueType ();
4523
+ EVT SVT = VT.getScalarType ();
4524
+ EVT ShVT = getShiftAmountTy (VT, DAG.getDataLayout ());
4525
+ EVT ShSVT = ShVT.getScalarType ();
4493
4526
4494
4527
// If MUL is unavailable, we cannot proceed in any case.
4495
4528
if (!isOperationLegalOrCustom (ISD::MUL, VT))
4496
4529
return SDValue ();
4497
4530
4498
- // TODO: Add non-uniform constant support.
4499
- ConstantSDNode *Divisor = isConstOrConstSplat (REMNode->getOperand (1 ));
4531
+ // TODO: Could support comparing with non-zero too.
4500
4532
ConstantSDNode *CompTarget = isConstOrConstSplat (CompTargetNode);
4501
- if (!Divisor || !CompTarget || Divisor->isNullValue () ||
4502
- !CompTarget->isNullValue ())
4533
+ if (!CompTarget || !CompTarget->isNullValue ())
4503
4534
return SDValue ();
4504
4535
4505
- const APInt &D = Divisor->getAPIntValue ();
4536
+ bool HadOneDivisor = false ;
4537
+ bool AllDivisorsAreOnes = true ;
4538
+ bool HadEvenDivisor = false ;
4539
+ bool AllDivisorsArePowerOfTwo = true ;
4540
+ SmallVector<SDValue, 16 > PAmts, KAmts, QAmts;
4541
+
4542
+ auto BuildUREMPattern = [&](ConstantSDNode *C) {
4543
+ // Division by 0 is UB. Leave it to be constant-folded elsewhere.
4544
+ if (C->isNullValue ())
4545
+ return false ;
4506
4546
4507
- // Decompose D into D0 * 2^K
4508
- unsigned K = D.countTrailingZeros ();
4509
- bool DivisorIsEven = (K != 0 );
4510
- APInt D0 = D.lshr (K);
4547
+ const APInt &D = C->getAPIntValue ();
4548
+ // If all divisors are ones, we will prefer to avoid the fold.
4549
+ HadOneDivisor |= D.isOneValue ();
4550
+ AllDivisorsAreOnes &= D.isOneValue ();
4551
+
4552
+ // Decompose D into D0 * 2^K
4553
+ unsigned K = D.countTrailingZeros ();
4554
+ assert ((!D.isOneValue () || (K == 0 )) && " For divisor '1' we won't rotate." );
4555
+ APInt D0 = D.lshr (K);
4556
+
4557
+ // D is even if it has trailing zeros.
4558
+ HadEvenDivisor |= (K != 0 );
4559
+ // D is a power-of-two if D0 is one.
4560
+ // If all divisors are power-of-two, we will prefer to avoid the fold.
4561
+ AllDivisorsArePowerOfTwo &= D0.isOneValue ();
4562
+
4563
+ // P = inv(D0, 2^W)
4564
+ // 2^W requires W + 1 bits, so we have to extend and then truncate.
4565
+ unsigned W = D.getBitWidth ();
4566
+ APInt P = D0.zext (W + 1 )
4567
+ .multiplicativeInverse (APInt::getSignedMinValue (W + 1 ))
4568
+ .trunc (W);
4569
+ assert (!P.isNullValue () && " No multiplicative inverse!" ); // unreachable
4570
+ assert ((D0 * P).isOneValue () && " Multiplicative inverse sanity check." );
4571
+
4572
+ // Q = floor((2^W - 1) / D)
4573
+ APInt Q = APInt::getAllOnesValue (W).udiv (D);
4574
+
4575
+ assert (APInt::getAllOnesValue (ShSVT.getSizeInBits ()).ugt (K) &&
4576
+ " We are expecting that K is always less than all-ones for ShSVT" );
4577
+
4578
+ // If the divisor is 1 the result can be constant-folded.
4579
+ if (D.isOneValue ()) {
4580
+ // Set P and K amounts to bogus values so we can try to splat them.
4581
+ P = 0 ;
4582
+ K = -1 ;
4583
+ assert (Q.isAllOnesValue () &&
4584
+ " Expecting all-ones comparison for one divisor" );
4585
+ }
4511
4586
4512
- // The fold is invalid when D0 == 1.
4513
- // This is reachable because visitSetCC happens before visitREM.
4514
- if (D0.isOneValue ())
4587
+ PAmts.push_back (DAG.getConstant (P, DL, SVT));
4588
+ KAmts.push_back (
4589
+ DAG.getConstant (APInt (ShSVT.getSizeInBits (), K), DL, ShSVT));
4590
+ QAmts.push_back (DAG.getConstant (Q, DL, SVT));
4591
+ return true ;
4592
+ };
4593
+
4594
+ SDValue N = REMNode.getOperand (0 );
4595
+ SDValue D = REMNode.getOperand (1 );
4596
+
4597
+ // Collect the values from each element.
4598
+ if (!ISD::matchUnaryPredicate (D, BuildUREMPattern))
4515
4599
return SDValue ();
4516
4600
4517
- // P = inv(D0, 2^W)
4518
- // 2^W requires W + 1 bits, so we have to extend and then truncate.
4519
- unsigned W = D.getBitWidth ();
4520
- APInt P = D0.zext (W + 1 )
4521
- .multiplicativeInverse (APInt::getSignedMinValue (W + 1 ))
4522
- .trunc (W);
4523
- assert (!P.isNullValue () && " No multiplicative inverse!" ); // unreachable
4524
- assert ((D0 * P).isOneValue () && " Multiplicative inverse sanity check." );
4601
+ // If this is a urem by one, avoid the fold since it can be constant-folded.
4602
+ if (AllDivisorsAreOnes)
4603
+ return SDValue ();
4525
4604
4526
- // Q = floor((2^W - 1) / D)
4527
- APInt Q = APInt::getAllOnesValue (W).udiv (D);
4605
+ // If this is a urem by powers-of-two, avoid the fold since it can be
4606
+ // best implemented as a bit test.
4607
+ if (AllDivisorsArePowerOfTwo)
4608
+ return SDValue ();
4528
4609
4529
- SelectionDAG &DAG = DCI.DAG ;
4610
+ SDValue PVal, KVal, QVal;
4611
+ if (VT.isVector ()) {
4612
+ if (HadOneDivisor) {
4613
+ // Try to turn PAmts into a splat, since we don't care about the values
4614
+ // that are currently '0'. If we can't, just keep the '0's.
4615
+ turnVectorIntoSplatVector (PAmts, isNullConstant);
4616
+ // Try to turn KAmts into a splat, since we don't care about the values
4617
+ // that are currently '-1'. If we can't, change them to '0's.
4618
+ turnVectorIntoSplatVector (KAmts, isAllOnesConstant,
4619
+ DAG.getConstant (0 , DL, ShSVT));
4620
+ }
4621
+
4622
+ PVal = DAG.getBuildVector (VT, DL, PAmts);
4623
+ KVal = DAG.getBuildVector (ShVT, DL, KAmts);
4624
+ QVal = DAG.getBuildVector (VT, DL, QAmts);
4625
+ } else {
4626
+ PVal = PAmts[0 ];
4627
+ KVal = KAmts[0 ];
4628
+ QVal = QAmts[0 ];
4629
+ }
4530
4630
4531
- SDValue PVal = DAG.getConstant (P, DL, VT);
4532
- SDValue QVal = DAG.getConstant (Q, DL, VT);
4533
4631
// (mul N, P)
4534
- SDValue Op1 = DAG.getNode (ISD::MUL, DL, VT, REMNode-> getOperand ( 0 ) , PVal);
4535
- Created.push_back (Op1 .getNode ());
4632
+ SDValue Op0 = DAG.getNode (ISD::MUL, DL, VT, N , PVal);
4633
+ Created.push_back (Op0 .getNode ());
4536
4634
4537
- // Rotate right only if D was even.
4538
- if (DivisorIsEven) {
4635
+ // Rotate right only if any divisor was even. We avoid rotates for all-odd
4636
+ // divisors as a performance improvement, since rotating by 0 is a no-op.
4637
+ if (HadEvenDivisor) {
4539
4638
// We need ROTR to do this.
4540
4639
if (!isOperationLegalOrCustom (ISD::ROTR, VT))
4541
4640
return SDValue ();
4542
- SDValue ShAmt =
4543
- DAG.getConstant (K, DL, getShiftAmountTy (VT, DAG.getDataLayout ()));
4544
4641
SDNodeFlags Flags;
4545
4642
Flags.setExact (true );
4546
4643
// UREM: (rotr (mul N, P), K)
4547
- Op1 = DAG.getNode (ISD::ROTR, DL, VT, Op1, ShAmt , Flags);
4548
- Created.push_back (Op1 .getNode ());
4644
+ Op0 = DAG.getNode (ISD::ROTR, DL, VT, Op0, KVal , Flags);
4645
+ Created.push_back (Op0 .getNode ());
4549
4646
}
4550
4647
4551
4648
// UREM: (setule/setugt (rotr (mul N, P), K), Q)
4552
- return DAG.getSetCC (DL, SETCCVT, Op1 , QVal,
4649
+ return DAG.getSetCC (DL, SETCCVT, Op0 , QVal,
4553
4650
((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
4554
4651
}
4555
4652
0 commit comments