Commit cd9b194

[Codegen][SelectionDAG] X u% C == 0 fold: non-splat vector improvements
Summary:
Four things here:
1. Generalize the fold to handle non-splat divisors. Reasonably trivial.
2. Unban power-of-two divisors. I don't see any reason why they should be illegal.
   * There is no ban in Hacker's Delight.
   * I think the ban came from the same bug that caused the miscompile in the base patch - in `floor((2^W - 1) / D)` we were dividing by `D0` instead of `D`, and we **were** ensuring that `D0` is not `1`, which made sense.
3. Unban `1` divisors. I no longer believe Hacker's Delight actually says that the fold is invalid for `D = 0`. Further considerations:
   * We know that
     * `(X u% 1) == 0` can be constant-folded to `1`,
     * `(X u% 1) != 0` can be constant-folded to `0`.
   * Also, we know that
     * `X u<= -1` can be constant-folded to `1`,
     * `X u> -1` can be constant-folded to `0`.
   * https://godbolt.org/z/7jnZJX https://rise4fun.com/Alive/oF6p
   * We know we will end up with the following: `(setule/setugt (rotr (mul N, P), K), Q)`
   * Therefore, for the given new DAG nodes and comparison predicates (`ule`/`ugt`), we will still produce the correct answer if `Q` is an all-ones constant and both `P` and `K` are *anything* other than `undef`.
   * The fold will indeed produce `Q = all-ones`.
4. Try to re-splat the `P` and `K` vectors - we don't care about their values for the lanes where the divisor was `1`.

Reviewers: RKSimon, hermord, craig.topper, spatel, xbolva00

Reviewed By: RKSimon

Subscribers: hiraditya, javed.absar, dexonsmith, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D63963

llvm-svn: 366637
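The fold described in the summary can be checked outside of LLVM. The following is a minimal sketch, not the patch's code: it fixes the width to W = 8 so every divisor and dividend can be enumerated, and the names `FoldConstants`, `inverseMod256`, `computeConstants`, `rotr8`, `foldedIsDivisible`, and `foldHoldsForAllInputs` are hypothetical helpers introduced only for illustration. It assumes the constants are derived as in the patch: `D = D0 * 2^K`, `P` is the inverse of `D0` modulo `2^W`, and `Q = floor((2^W - 1) / D)`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical container for the three constants of the fold (W = 8).
struct FoldConstants {
  uint8_t P;  // multiplicative inverse of D0 modulo 2^8
  unsigned K; // number of trailing zero bits in D
  uint8_t Q;  // floor((2^8 - 1) / D)
};

// Inverse of an odd number modulo 2^8 via Newton iteration.
// For odd D0, D0 * D0 == 1 (mod 8), so D0 is already correct to 3 bits;
// each step doubles the number of correct low bits.
static uint8_t inverseMod256(uint8_t D0) {
  uint32_t P = D0;
  for (int I = 0; I < 3; ++I)
    P = (P * (2u - D0 * P)) & 0xFFu;
  return static_cast<uint8_t>(P);
}

// Decompose D (which must be non-zero) into D0 * 2^K and derive P and Q.
static FoldConstants computeConstants(uint8_t D) {
  unsigned K = 0;
  while (((D >> K) & 1u) == 0)
    ++K;
  uint8_t D0 = static_cast<uint8_t>(D >> K);
  return {inverseMod256(D0), K, static_cast<uint8_t>(0xFFu / D)};
}

// 8-bit rotate right.
static uint8_t rotr8(uint8_t V, unsigned K) {
  K &= 7u;
  return static_cast<uint8_t>((V >> K) | (V << ((8u - K) & 7u)));
}

// (X u% D) == 0 rewritten as (rotr (mul X, P), K) u<= Q.
static bool foldedIsDivisible(uint8_t X, uint8_t D) {
  FoldConstants C = computeConstants(D);
  uint8_t Mul = static_cast<uint8_t>(X * C.P);
  return rotr8(Mul, C.K) <= C.Q;
}

// Compare the rewritten form against the plain remainder for every input,
// including D == 1 and power-of-two divisors.
static bool foldHoldsForAllInputs() {
  for (unsigned D = 1; D < 256; ++D)
    for (unsigned X = 0; X < 256; ++X)
      if ((X % D == 0) != foldedIsDivisible(static_cast<uint8_t>(X),
                                            static_cast<uint8_t>(D)))
        return false;
  return true;
}
```

Calling `foldHoldsForAllInputs()` exercises the cases the summary argues about: for `D = 1` the sketch yields `Q = 255`, so the `u<=` comparison is always true, and for power-of-two divisors `D0 = 1` gives `P = 1`, which reduces the test to checking the low `K` bits.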
1 parent adec0f2 commit cd9b194

3 files changed: +289 -666 lines changed
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 132 additions & 35 deletions
@@ -4455,6 +4455,34 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
   return DAG.getSelect(dl, VT, IsOne, N0, Q);
 }
 
+/// If all values in Values that *don't* match the predicate are same 'splat'
+/// value, then replace all values with that splat value.
+/// Else, if AlternativeReplacement was provided, then replace all values that
+/// do match predicate with AlternativeReplacement value.
+static void
+turnVectorIntoSplatVector(MutableArrayRef<SDValue> Values,
+                          std::function<bool(SDValue)> Predicate,
+                          SDValue AlternativeReplacement = SDValue()) {
+  SDValue Replacement;
+  // Is there a value for which the Predicate does *NOT* match? What is it?
+  auto SplatValue = llvm::find_if_not(Values, Predicate);
+  if (SplatValue != Values.end()) {
+    // Does Values consist only of SplatValue's and values matching Predicate?
+    if (llvm::all_of(Values, [Predicate, SplatValue](SDValue Value) {
+          return Value == *SplatValue || Predicate(Value);
+        })) // Then we shall replace values matching predicate with SplatValue.
+      Replacement = *SplatValue;
+  }
+  if (!Replacement) {
+    // Oops, we did not find the "baseline" splat value.
+    if (!AlternativeReplacement)
+      return; // Nothing to do.
+    // Let's replace with provided value then.
+    Replacement = AlternativeReplacement;
+  }
+  std::replace_if(Values.begin(), Values.end(), Predicate, Replacement);
+}
+
 /// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE
 /// where the divisor is constant and the comparison target is zero,
 /// return a DAG expression that will generate the same comparison result
@@ -4482,74 +4510,143 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
                                   DAGCombinerInfo &DCI, const SDLoc &DL,
                                   SmallVectorImpl<SDNode *> &Created) const {
   // fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q)
-  // - D must be constant with D = D0 * 2^K where D0 is odd and D0 != 1
+  // - D must be constant, with D = D0 * 2^K where D0 is odd
   // - P is the multiplicative inverse of D0 modulo 2^W
   // - Q = floor((2^W - 1) / D0)
   // where W is the width of the common type of N and D.
   assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
          "Only applicable for (in)equality comparisons.");
 
+  SelectionDAG &DAG = DCI.DAG;
+
   EVT VT = REMNode.getValueType();
+  EVT SVT = VT.getScalarType();
+  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
+  EVT ShSVT = ShVT.getScalarType();
 
   // If MUL is unavailable, we cannot proceed in any case.
   if (!isOperationLegalOrCustom(ISD::MUL, VT))
     return SDValue();
 
-  // TODO: Add non-uniform constant support.
-  ConstantSDNode *Divisor = isConstOrConstSplat(REMNode->getOperand(1));
+  // TODO: Could support comparing with non-zero too.
   ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode);
-  if (!Divisor || !CompTarget || Divisor->isNullValue() ||
-      !CompTarget->isNullValue())
+  if (!CompTarget || !CompTarget->isNullValue())
     return SDValue();
 
-  const APInt &D = Divisor->getAPIntValue();
+  bool HadOneDivisor = false;
+  bool AllDivisorsAreOnes = true;
+  bool HadEvenDivisor = false;
+  bool AllDivisorsArePowerOfTwo = true;
+  SmallVector<SDValue, 16> PAmts, KAmts, QAmts;
+
+  auto BuildUREMPattern = [&](ConstantSDNode *C) {
+    // Division by 0 is UB. Leave it to be constant-folded elsewhere.
+    if (C->isNullValue())
+      return false;
 
-  // Decompose D into D0 * 2^K
-  unsigned K = D.countTrailingZeros();
-  bool DivisorIsEven = (K != 0);
-  APInt D0 = D.lshr(K);
+    const APInt &D = C->getAPIntValue();
+    // If all divisors are ones, we will prefer to avoid the fold.
+    HadOneDivisor |= D.isOneValue();
+    AllDivisorsAreOnes &= D.isOneValue();
+
+    // Decompose D into D0 * 2^K
+    unsigned K = D.countTrailingZeros();
+    assert((!D.isOneValue() || (K == 0)) && "For divisor '1' we won't rotate.");
+    APInt D0 = D.lshr(K);
+
+    // D is even if it has trailing zeros.
+    HadEvenDivisor |= (K != 0);
+    // D is a power-of-two if D0 is one.
+    // If all divisors are power-of-two, we will prefer to avoid the fold.
+    AllDivisorsArePowerOfTwo &= D0.isOneValue();
+
+    // P = inv(D0, 2^W)
+    // 2^W requires W + 1 bits, so we have to extend and then truncate.
+    unsigned W = D.getBitWidth();
+    APInt P = D0.zext(W + 1)
+                  .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
+                  .trunc(W);
+    assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable
+    assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check.");
+
+    // Q = floor((2^W - 1) / D)
+    APInt Q = APInt::getAllOnesValue(W).udiv(D);
+
+    assert(APInt::getAllOnesValue(ShSVT.getSizeInBits()).ugt(K) &&
+           "We are expecting that K is always less than all-ones for ShSVT");
+
+    // If the divisor is 1 the result can be constant-folded.
+    if (D.isOneValue()) {
+      // Set P and K amount to a bogus values so we can try to splat them.
+      P = 0;
+      K = -1;
+      assert(Q.isAllOnesValue() &&
+             "Expecting all-ones comparison for one divisor");
+    }
 
-  // The fold is invalid when D0 == 1.
-  // This is reachable because visitSetCC happens before visitREM.
-  if (D0.isOneValue())
+    PAmts.push_back(DAG.getConstant(P, DL, SVT));
+    KAmts.push_back(
+        DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT));
+    QAmts.push_back(DAG.getConstant(Q, DL, SVT));
+    return true;
+  };
+
+  SDValue N = REMNode.getOperand(0);
+  SDValue D = REMNode.getOperand(1);
+
+  // Collect the values from each element.
+  if (!ISD::matchUnaryPredicate(D, BuildUREMPattern))
     return SDValue();
 
-  // P = inv(D0, 2^W)
-  // 2^W requires W + 1 bits, so we have to extend and then truncate.
-  unsigned W = D.getBitWidth();
-  APInt P = D0.zext(W + 1)
-                .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
-                .trunc(W);
-  assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable
-  assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check.");
+  // If this is a urem by a one, avoid the fold since it can be constant-folded.
+  if (AllDivisorsAreOnes)
+    return SDValue();
 
-  // Q = floor((2^W - 1) / D)
-  APInt Q = APInt::getAllOnesValue(W).udiv(D);
+  // If this is a urem by a powers-of-two, avoid the fold since it can be
+  // best implemented as a bit test.
+  if (AllDivisorsArePowerOfTwo)
+    return SDValue();
 
-  SelectionDAG &DAG = DCI.DAG;
+  SDValue PVal, KVal, QVal;
+  if (VT.isVector()) {
+    if (HadOneDivisor) {
+      // Try to turn PAmts into a splat, since we don't care about the values
+      // that are currently '0'. If we can't, just keep '0'`s.
+      turnVectorIntoSplatVector(PAmts, isNullConstant);
+      // Try to turn KAmts into a splat, since we don't care about the values
+      // that are currently '-1'. If we can't, change them to '0'`s.
+      turnVectorIntoSplatVector(KAmts, isAllOnesConstant,
+                                DAG.getConstant(0, DL, ShSVT));
+    }
+
+    PVal = DAG.getBuildVector(VT, DL, PAmts);
+    KVal = DAG.getBuildVector(ShVT, DL, KAmts);
+    QVal = DAG.getBuildVector(VT, DL, QAmts);
+  } else {
+    PVal = PAmts[0];
+    KVal = KAmts[0];
+    QVal = QAmts[0];
+  }
 
-  SDValue PVal = DAG.getConstant(P, DL, VT);
-  SDValue QVal = DAG.getConstant(Q, DL, VT);
   // (mul N, P)
-  SDValue Op1 = DAG.getNode(ISD::MUL, DL, VT, REMNode->getOperand(0), PVal);
-  Created.push_back(Op1.getNode());
+  SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal);
+  Created.push_back(Op0.getNode());
 
-  // Rotate right only if D was even.
-  if (DivisorIsEven) {
+  // Rotate right only if any divisor was even. We avoid rotates for all-odd
+  // divisors as a performance improvement, since rotating by 0 is a no-op.
+  if (HadEvenDivisor) {
     // We need ROTR to do this.
     if (!isOperationLegalOrCustom(ISD::ROTR, VT))
       return SDValue();
-    SDValue ShAmt =
-        DAG.getConstant(K, DL, getShiftAmountTy(VT, DAG.getDataLayout()));
     SDNodeFlags Flags;
     Flags.setExact(true);
     // UREM: (rotr (mul N, P), K)
-    Op1 = DAG.getNode(ISD::ROTR, DL, VT, Op1, ShAmt, Flags);
-    Created.push_back(Op1.getNode());
+    Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal, Flags);
+    Created.push_back(Op0.getNode());
   }
 
   // UREM: (setule/setugt (rotr (mul N, P), K), Q)
-  return DAG.getSetCC(DL, SETCCVT, Op1, QVal,
+  return DAG.getSetCC(DL, SETCCVT, Op0, QVal,
                       ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
 }
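The re-splatting step in the diff above can be illustrated without any LLVM types. The sketch below mirrors the logic of `turnVectorIntoSplatVector` on plain integers; `turnIntoSplat` is a hypothetical stand-in, with `std::vector<int>` replacing `MutableArrayRef<SDValue>` and `std::optional<int>` replacing the possibly-null `SDValue`. It is an approximation for reading purposes, not the patch's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical plain-integer analogue of turnVectorIntoSplatVector.
// Lanes matching Pred are "don't care" lanes (the bogus P = 0 / K = -1
// placeholders that the patch stores for divisor 1).
static void turnIntoSplat(std::vector<int> &Values,
                          const std::function<bool(int)> &Pred,
                          std::optional<int> Alternative = std::nullopt) {
  std::optional<int> Replacement;
  // Find the first lane that carries a real value.
  auto Splat = std::find_if_not(Values.begin(), Values.end(), Pred);
  if (Splat != Values.end()) {
    int Candidate = *Splat;
    // Every lane must either be a placeholder or equal that real value.
    bool IsSplattable =
        std::all_of(Values.begin(), Values.end(), [&](int V) {
          return V == Candidate || Pred(V);
        });
    if (IsSplattable)
      Replacement = Candidate;
  }
  if (!Replacement) {
    // No common value exists; fall back to the alternative if one was given.
    if (!Alternative)
      return;
    Replacement = Alternative;
  }
  std::replace_if(Values.begin(), Values.end(), Pred, *Replacement);
}
```

With placeholder `0` for `P`, the lanes `{3, 0, 3, 3}` become `{3, 3, 3, 3}`, while `{3, 0, 5, 0}` stays unchanged because no single value covers the real lanes. With placeholder `-1` for `K` and an alternative of `0`, `{1, -1, 2, -1}` becomes `{1, 0, 2, 0}`, matching the "change them to '0's" comment in the patch.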

llvm/test/CodeGen/AArch64/urem-seteq-vec-nonsplat.ll

Lines changed: 17 additions & 51 deletions
@@ -40,18 +40,11 @@ define <4 x i32> @test_urem_odd_allones_eq(<4 x i32> %X) nounwind {
 ; CHECK-LABEL: test_urem_odd_allones_eq:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: adrp x8, .LCPI1_0
+; CHECK-NEXT: adrp x9, .LCPI1_1
 ; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI1_0]
-; CHECK-NEXT: adrp x8, .LCPI1_1
-; CHECK-NEXT: ldr q2, [x8, :lo12:.LCPI1_1]
-; CHECK-NEXT: adrp x8, .LCPI1_2
-; CHECK-NEXT: ldr q3, [x8, :lo12:.LCPI1_2]
-; CHECK-NEXT: umull2 v4.2d, v0.4s, v1.4s
-; CHECK-NEXT: umull v1.2d, v0.2s, v1.2s
-; CHECK-NEXT: uzp2 v1.4s, v1.4s, v4.4s
-; CHECK-NEXT: neg v2.4s, v2.4s
-; CHECK-NEXT: ushl v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: mls v0.4s, v1.4s, v3.4s
-; CHECK-NEXT: cmeq v0.4s, v0.4s, #0
+; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI1_1]
+; CHECK-NEXT: mul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: cmhs v0.4s, v2.4s, v0.4s
 ; CHECK-NEXT: movi v1.4s, #1
 ; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
@@ -64,19 +57,11 @@ define <4 x i32> @test_urem_odd_allones_ne(<4 x i32> %X) nounwind {
 ; CHECK-LABEL: test_urem_odd_allones_ne:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: adrp x8, .LCPI2_0
+; CHECK-NEXT: adrp x9, .LCPI2_1
 ; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI2_0]
-; CHECK-NEXT: adrp x8, .LCPI2_1
-; CHECK-NEXT: ldr q2, [x8, :lo12:.LCPI2_1]
-; CHECK-NEXT: adrp x8, .LCPI2_2
-; CHECK-NEXT: ldr q3, [x8, :lo12:.LCPI2_2]
-; CHECK-NEXT: umull2 v4.2d, v0.4s, v1.4s
-; CHECK-NEXT: umull v1.2d, v0.2s, v1.2s
-; CHECK-NEXT: uzp2 v1.4s, v1.4s, v4.4s
-; CHECK-NEXT: neg v2.4s, v2.4s
-; CHECK-NEXT: ushl v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: mls v0.4s, v1.4s, v3.4s
-; CHECK-NEXT: cmeq v0.4s, v0.4s, #0
-; CHECK-NEXT: mvn v0.16b, v0.16b
+; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI2_1]
+; CHECK-NEXT: mul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: cmhi v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT: movi v1.4s, #1
 ; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
@@ -300,20 +285,11 @@ define <4 x i32> @test_urem_odd_one(<4 x i32> %X) nounwind {
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: adrp x8, .LCPI10_0
 ; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI10_0]
-; CHECK-NEXT: adrp x8, .LCPI10_1
-; CHECK-NEXT: ldr q2, [x8, :lo12:.LCPI10_1]
-; CHECK-NEXT: adrp x8, .LCPI10_2
-; CHECK-NEXT: ldr q3, [x8, :lo12:.LCPI10_2]
-; CHECK-NEXT: adrp x8, .LCPI10_3
-; CHECK-NEXT: umull2 v4.2d, v0.4s, v1.4s
-; CHECK-NEXT: umull v1.2d, v0.2s, v1.2s
-; CHECK-NEXT: uzp2 v1.4s, v1.4s, v4.4s
-; CHECK-NEXT: ldr q4, [x8, :lo12:.LCPI10_3]
-; CHECK-NEXT: neg v2.4s, v2.4s
-; CHECK-NEXT: ushl v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: bsl v3.16b, v0.16b, v1.16b
-; CHECK-NEXT: mls v0.4s, v3.4s, v4.4s
-; CHECK-NEXT: cmeq v0.4s, v0.4s, #0
+; CHECK-NEXT: mov w8, #52429
+; CHECK-NEXT: movk w8, #52428, lsl #16
+; CHECK-NEXT: dup v2.4s, w8
+; CHECK-NEXT: mul v0.4s, v0.4s, v2.4s
+; CHECK-NEXT: cmhs v0.4s, v1.4s, v0.4s
 ; CHECK-NEXT: movi v1.4s, #1
 ; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
@@ -480,21 +456,11 @@ define <4 x i32> @test_urem_odd_allones_and_one(<4 x i32> %X) nounwind {
 ; CHECK-LABEL: test_urem_odd_allones_and_one:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: adrp x8, .LCPI16_0
+; CHECK-NEXT: adrp x9, .LCPI16_1
 ; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI16_0]
-; CHECK-NEXT: adrp x8, .LCPI16_1
-; CHECK-NEXT: ldr q2, [x8, :lo12:.LCPI16_1]
-; CHECK-NEXT: adrp x8, .LCPI16_2
-; CHECK-NEXT: ldr q3, [x8, :lo12:.LCPI16_2]
-; CHECK-NEXT: adrp x8, .LCPI16_3
-; CHECK-NEXT: umull2 v4.2d, v0.4s, v1.4s
-; CHECK-NEXT: umull v1.2d, v0.2s, v1.2s
-; CHECK-NEXT: uzp2 v1.4s, v1.4s, v4.4s
-; CHECK-NEXT: ldr q4, [x8, :lo12:.LCPI16_3]
-; CHECK-NEXT: neg v2.4s, v2.4s
-; CHECK-NEXT: ushl v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: bsl v3.16b, v0.16b, v1.16b
-; CHECK-NEXT: mls v0.4s, v3.4s, v4.4s
-; CHECK-NEXT: cmeq v0.4s, v0.4s, #0
+; CHECK-NEXT: ldr q2, [x9, :lo12:.LCPI16_1]
+; CHECK-NEXT: mul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: cmhs v0.4s, v2.4s, v0.4s
 ; CHECK-NEXT: movi v1.4s, #1
 ; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT: ret
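
The `mov w8, #52429` / `movk w8, #52428, lsl #16` pair in `test_urem_odd_one` materializes a single 32-bit constant that is then splatted with `dup`. The sketch below checks that this constant is `0xCCCCCCCD`, the multiplicative inverse of 5 modulo 2^32. That the non-`1` lanes of this test divide by 5 is an assumption here, since the IR divisor vector is not shown on this page; `materializedConstant` and `isDivisibleBy5` are hypothetical names.

```cpp
#include <cstdint>

// Value built by `mov w8, #52429` followed by `movk w8, #52428, lsl #16`.
constexpr uint32_t materializedConstant() {
  return 52429u | (52428u << 16);
}

static_assert(materializedConstant() == 0xCCCCCCCDu,
              "mov/movk pair builds 0xCCCCCCCD");
// 5 * 0xCCCCCCCD == 2^34 + 1, which is 1 modulo 2^32.
static_assert(static_cast<uint32_t>(materializedConstant() * 5u) == 1u,
              "0xCCCCCCCD is the inverse of 5 modulo 2^32");

// Scalar form of the fold for D = 5 (odd, so K = 0 and no rotate is needed):
// (X u% 5) == 0  <=>  X * P u<= floor((2^32 - 1) / 5).
constexpr bool isDivisibleBy5(uint32_t X) {
  return static_cast<uint32_t>(X * materializedConstant()) <= 0xFFFFFFFFu / 5u;
}
```

Because the same `P` is used in every lane after re-splatting, the compiler can build it with `mov`/`movk`/`dup` instead of loading it from a constant pool, which is consistent with the removed `adrp`/`ldr` pairs in that hunk.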