Skip to content

Commit fe9de4b

Browse files
committed
feat: move combineVSelectWithAllOnesOrZeros to DAGCombiner and x86 test
1 parent 092ef1d commit fe9de4b

File tree

3 files changed

+98
-73
lines changed

3 files changed

+98
-73
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12945,6 +12945,85 @@ SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
1294512945
return SDValue();
1294612946
}
1294712947

12948+
/// Try to fold a vector select whose true and/or false operand is a constant
/// all-zeros or all-ones build vector into a bitwise logic operation on the
/// condition mask.
///
/// Folds performed:
///   vselect Cond,  0,  0 -> 0
///   vselect Cond, -1,  0 -> bitcast Cond
///   vselect Cond, -1,  X -> or  Cond, freeze(X)
///   vselect Cond,  X,  0 -> and Cond, freeze(X)
/// If the true value is all-zeros or the false value is all-ones and Cond is a
/// single-use SETCC of the target's setcc result type, the comparison is
/// inverted and the operands are swapped so one of the folds above applies.
///
/// \param Cond Vector condition operand of the select.
/// \param TVal Value selected on lanes where Cond is true.
/// \param FVal Value selected on lanes where Cond is false.
/// \param TLI  Target lowering info, used for the type-legality and setcc
///             result-type queries.
/// \param DAG  DAG in which replacement nodes are created.
/// \param DL   Debug location attached to the created nodes.
/// \returns The replacement value, or an empty SDValue if no fold applies.
static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
                                                SDValue FVal,
                                                const TargetLowering &TLI,
                                                SelectionDAG &DAG,
                                                const SDLoc &DL) {
  EVT VT = TVal.getValueType();
  if (!TLI.isTypeLegal(VT))
    return SDValue();

  EVT CondVT = Cond.getValueType();
  assert(CondVT.isVector() && "Vector select expects a vector selector!");

  // Classify TVal/FVal content.
  bool IsTAllZero = ISD::isBuildVectorAllZeros(TVal.getNode());
  bool IsTAllOne = ISD::isBuildVectorAllOnes(TVal.getNode());
  bool IsFAllZero = ISD::isBuildVectorAllZeros(FVal.getNode());
  bool IsFAllOne = ISD::isBuildVectorAllOnes(FVal.getNode());

  // Neither vselect(cond, 0/-1, X) nor vselect(cond, X, 0/-1): nothing to do.
  if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
    return SDValue();

  // select Cond, 0, 0 -> 0
  if (IsTAllZero && IsFAllZero)
    return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
                                : DAG.getConstant(0, DL, VT);

  // To use the condition operand as a bitwise mask, it must have elements that
  // are the same size as the select elements. Ie, the condition operand must
  // have already been promoted from the IR select condition type <N x i1>.
  // Don't check if the types themselves are equal because that excludes
  // vector floating-point selects.
  if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
    return SDValue();

  // Cond value must be 'sign splat' to be converted to a logical op. Check
  // this before attempting the inversion below so that bailing out here never
  // leaves a freshly created, unused SETCC node behind in the DAG.
  if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
    return SDValue();

  // Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form.
  // Only do this when the condition has a single use and is a SETCC that
  // already has the target's setcc result type for VT.
  if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
      Cond.getOpcode() == ISD::SETCC &&
      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
          CondVT) {
    if (IsTAllZero || IsFAllOne) {
      SDValue CC = Cond.getOperand(2);
      ISD::CondCode InverseCC = ISD::getSetCCInverse(
          cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
      Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
                          InverseCC);
      std::swap(TVal, FVal);
      std::swap(IsTAllOne, IsFAllOne);
      std::swap(IsTAllZero, IsFAllZero);
    }
  }

  // select Cond, -1, 0 -> bitcast Cond
  if (IsTAllOne && IsFAllZero)
    return DAG.getBitcast(VT, Cond);

  // The select only exposes X on the lanes where Cond picks it, whereas the
  // OR/AND below consume X on every lane. Freeze X so that undef/poison in X
  // cannot leak into lanes that previously produced the constant.

  // select Cond, -1, x -> or Cond, freeze(x)
  if (IsTAllOne) {
    SDValue X = DAG.getBitcast(CondVT, DAG.getFreeze(FVal));
    SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, Or);
  }

  // select Cond, x, 0 -> and Cond, freeze(x)
  if (IsFAllZero) {
    SDValue X = DAG.getBitcast(CondVT, DAG.getFreeze(TVal));
    SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, And);
  }

  return SDValue();
}
13026+
1294813027
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1294913028
SDValue N0 = N->getOperand(0);
1295013029
SDValue N1 = N->getOperand(1);
@@ -13213,6 +13292,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1321313292
if (SimplifyDemandedVectorElts(SDValue(N, 0)))
1321413293
return SDValue(N, 0);
1321513294

13295+
if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
13296+
return V;
13297+
1321613298
return SDValue();
1321713299
}
1321813300

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -47264,13 +47264,14 @@ static SDValue combineToExtendBoolVectorInReg(
4726447264
DAG.getConstant(EltSizeInBits - 1, DL, VT));
4726547265
}
4726647266

47267-
/// If a vector select has an operand that is -1 or 0, try to simplify the
47267+
/// If a vector select has a left operand that is 0, try to simplify the
4726847268
/// select to a bitwise logic operation.
47269-
/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
47270-
static SDValue
47271-
combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
47272-
TargetLowering::DAGCombinerInfo &DCI,
47273-
const X86Subtarget &Subtarget) {
47269+
/// TODO: Move to DAGCombiner.combineVSelectWithAllOnesOrZeros, possibly using
47270+
/// TargetLowering::hasAndNot()?
47271+
static SDValue combineVSelectWithLastZeros(SDNode *N, SelectionDAG &DAG,
47272+
const SDLoc &DL,
47273+
TargetLowering::DAGCombinerInfo &DCI,
47274+
const X86Subtarget &Subtarget) {
4727447275
SDValue Cond = N->getOperand(0);
4727547276
SDValue LHS = N->getOperand(1);
4727647277
SDValue RHS = N->getOperand(2);
@@ -47283,20 +47284,6 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
4728347284

4728447285
assert(CondVT.isVector() && "Vector select expects a vector selector!");
4728547286

47286-
// TODO: Use isNullOrNullSplat() to distinguish constants with undefs?
47287-
// TODO: Can we assert that both operands are not zeros (because that should
47288-
// get simplified at node creation time)?
47289-
bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode());
47290-
bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
47291-
47292-
// If both inputs are 0/undef, create a complete zero vector.
47293-
// FIXME: As noted above this should be handled by DAGCombiner/getNode.
47294-
if (TValIsAllZeros && FValIsAllZeros) {
47295-
if (VT.isFloatingPoint())
47296-
return DAG.getConstantFP(0.0, DL, VT);
47297-
return DAG.getConstant(0, DL, VT);
47298-
}
47299-
4730047287
// To use the condition operand as a bitwise mask, it must have elements that
4730147288
// are the same size as the select elements. Ie, the condition operand must
4730247289
// have already been promoted from the IR select condition type <N x i1>.
@@ -47305,56 +47292,15 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
4730547292
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
4730647293
return SDValue();
4730747294

47308-
// Try to invert the condition if true value is not all 1s and false value is
47309-
// not all 0s. Only do this if the condition has one use.
47310-
bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode());
47311-
if (!TValIsAllOnes && !FValIsAllZeros && Cond.hasOneUse() &&
47312-
// Check if the selector will be produced by CMPP*/PCMP*.
47313-
Cond.getOpcode() == ISD::SETCC &&
47314-
// Check if SETCC has already been promoted.
47315-
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
47316-
CondVT) {
47317-
bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode());
47318-
47319-
if (TValIsAllZeros || FValIsAllOnes) {
47320-
SDValue CC = Cond.getOperand(2);
47321-
ISD::CondCode NewCC = ISD::getSetCCInverse(
47322-
cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
47323-
Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
47324-
NewCC);
47325-
std::swap(LHS, RHS);
47326-
TValIsAllOnes = FValIsAllOnes;
47327-
FValIsAllZeros = TValIsAllZeros;
47328-
}
47329-
}
47330-
4733147295
// Cond value must be 'sign splat' to be converted to a logical op.
4733247296
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
4733347297
return SDValue();
4733447298

47335-
// vselect Cond, 111..., 000... -> Cond
47336-
if (TValIsAllOnes && FValIsAllZeros)
47337-
return DAG.getBitcast(VT, Cond);
47338-
4733947299
if (!TLI.isTypeLegal(CondVT))
4734047300
return SDValue();
4734147301

47342-
// vselect Cond, 111..., X -> or Cond, X
47343-
if (TValIsAllOnes) {
47344-
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
47345-
SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, CastRHS);
47346-
return DAG.getBitcast(VT, Or);
47347-
}
47348-
47349-
// vselect Cond, X, 000... -> and Cond, X
47350-
if (FValIsAllZeros) {
47351-
SDValue CastLHS = DAG.getBitcast(CondVT, LHS);
47352-
SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, CastLHS);
47353-
return DAG.getBitcast(VT, And);
47354-
}
47355-
4735647302
// vselect Cond, 000..., X -> andn Cond, X
47357-
if (TValIsAllZeros) {
47303+
if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
4735847304
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
4735947305
SDValue AndN;
4736047306
// The canonical form differs for i1 vectors - x86andnp is not used
@@ -48117,7 +48063,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4811748063
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
4811848064
return SDValue();
4811948065

48120-
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DL, DCI, Subtarget))
48066+
if (SDValue V = combineVSelectWithLastZeros(N, DAG, DL, DCI, Subtarget))
4812148067
return V;
4812248068

4812348069
if (SDValue V = combineVSelectToBLENDV(N, DAG, DL, DCI, Subtarget))

llvm/test/CodeGen/X86/urem-seteq-vec-tautological.ll

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,9 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
7777
; CHECK-SSE2-LABEL: t1_all_odd_ne:
7878
; CHECK-SSE2: # %bb.0:
7979
; CHECK-SSE2-NEXT: pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
80-
; CHECK-SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
8180
; CHECK-SSE2-NEXT: pxor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
8281
; CHECK-SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
83-
; CHECK-SSE2-NEXT: pcmpeqd %xmm1, %xmm1
84-
; CHECK-SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
82+
; CHECK-SSE2-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
8583
; CHECK-SSE2-NEXT: retq
8684
;
8785
; CHECK-SSE41-LABEL: t1_all_odd_ne:
@@ -92,7 +90,7 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
9290
; CHECK-SSE41-NEXT: pcmpeqd %xmm1, %xmm0
9391
; CHECK-SSE41-NEXT: pcmpeqd %xmm1, %xmm1
9492
; CHECK-SSE41-NEXT: pxor %xmm1, %xmm0
95-
; CHECK-SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
93+
; CHECK-SSE41-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
9694
; CHECK-SSE41-NEXT: retq
9795
;
9896
; CHECK-AVX1-LABEL: t1_all_odd_ne:
@@ -102,7 +100,7 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
102100
; CHECK-AVX1-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
103101
; CHECK-AVX1-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
104102
; CHECK-AVX1-NEXT: vpxor %xmm1, %xmm0, %xmm0
105-
; CHECK-AVX1-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
103+
; CHECK-AVX1-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
106104
; CHECK-AVX1-NEXT: retq
107105
;
108106
; CHECK-AVX2-LABEL: t1_all_odd_ne:
@@ -113,17 +111,16 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
113111
; CHECK-AVX2-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
114112
; CHECK-AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
115113
; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
116-
; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
114+
; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
117115
; CHECK-AVX2-NEXT: retq
118116
;
119117
; CHECK-AVX512VL-LABEL: t1_all_odd_ne:
120118
; CHECK-AVX512VL: # %bb.0:
121119
; CHECK-AVX512VL-NEXT: vpmulld {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to4}, %xmm0, %xmm0
122120
; CHECK-AVX512VL-NEXT: vpminud {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
123-
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
124-
; CHECK-AVX512VL-NEXT: vpternlogq {{.*#+}} xmm0 = ~xmm0
125-
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
126-
; CHECK-AVX512VL-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
121+
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm1
122+
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
123+
; CHECK-AVX512VL-NEXT: vpternlogq {{.*#+}} xmm0 = m64bcst | (xmm0 ^ xmm1)
127124
; CHECK-AVX512VL-NEXT: retq
128125
%urem = urem <4 x i32> %X, <i32 3, i32 1, i32 1, i32 9>
129126
%cmp = icmp ne <4 x i32> %urem, <i32 0, i32 42, i32 0, i32 42>

0 commit comments

Comments
 (0)