Skip to content

Commit 8e38989

Browse files
committed
feat: move combineVSelectWithAllOnesOrZeros to DAGCombiner and x86 test
1 parent 67be4fe commit 8e38989

File tree

3 files changed

+98
-73
lines changed

3 files changed

+98
-73
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13069,6 +13069,85 @@ SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
1306913069
return SDValue();
1307013070
}
1307113071

13072+
static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
13073+
SDValue FVal,
13074+
const TargetLowering &TLI,
13075+
SelectionDAG &DAG,
13076+
const SDLoc &DL) {
13077+
if (!TLI.isTypeLegal(TVal.getValueType()))
13078+
return SDValue();
13079+
13080+
EVT VT = TVal.getValueType();
13081+
EVT CondVT = Cond.getValueType();
13082+
13083+
assert(CondVT.isVector() && "Vector select expects a vector selector!");
13084+
13085+
// Classify TVal/FVal content
13086+
bool IsTAllZero = ISD::isBuildVectorAllZeros(TVal.getNode());
13087+
bool IsTAllOne = ISD::isBuildVectorAllOnes(TVal.getNode());
13088+
bool IsFAllZero = ISD::isBuildVectorAllZeros(FVal.getNode());
13089+
bool IsFAllOne = ISD::isBuildVectorAllOnes(FVal.getNode());
13090+
13091+
// no vselect(cond, 0/-1, X) or vselect(cond, X, 0/-1), return
13092+
if (!(IsTAllZero || IsTAllOne || IsFAllZero || IsFAllOne))
13093+
return SDValue();
13094+
13095+
// select Cond, 0, 0 → 0
13096+
if (IsTAllZero && IsFAllZero) {
13097+
return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
13098+
: DAG.getConstant(0, DL, VT);
13099+
}
13100+
13101+
// To use the condition operand as a bitwise mask, it must have elements that
13102+
// are the same size as the select elements. Ie, the condition operand must
13103+
// have already been promoted from the IR select condition type <N x i1>.
13104+
// Don't check if the types themselves are equal because that excludes
13105+
// vector floating-point selects.
13106+
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
13107+
return SDValue();
13108+
13109+
// Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form
13110+
if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
13111+
Cond.getOpcode() == ISD::SETCC &&
13112+
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
13113+
CondVT) {
13114+
if (IsTAllZero || IsFAllOne) {
13115+
SDValue CC = Cond.getOperand(2);
13116+
ISD::CondCode InverseCC = ISD::getSetCCInverse(
13117+
cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
13118+
Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
13119+
InverseCC);
13120+
std::swap(TVal, FVal);
13121+
std::swap(IsTAllOne, IsFAllOne);
13122+
std::swap(IsTAllZero, IsFAllZero);
13123+
}
13124+
}
13125+
13126+
// Cond value must be 'sign splat' to be converted to a logical op.
13127+
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
13128+
return SDValue();
13129+
13130+
// select Cond, -1, 0 → bitcast Cond
13131+
if (IsTAllOne && IsFAllZero)
13132+
return DAG.getBitcast(VT, Cond);
13133+
13134+
// select Cond, -1, x → or Cond, x
13135+
if (IsTAllOne) {
13136+
SDValue X = DAG.getBitcast(CondVT, FVal);
13137+
SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
13138+
return DAG.getBitcast(VT, Or);
13139+
}
13140+
13141+
// select Cond, x, 0 → and Cond, x
13142+
if (IsFAllZero) {
13143+
SDValue X = DAG.getBitcast(CondVT, TVal);
13144+
SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
13145+
return DAG.getBitcast(VT, And);
13146+
}
13147+
13148+
return SDValue();
13149+
}
13150+
1307213151
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1307313152
SDValue N0 = N->getOperand(0);
1307413153
SDValue N1 = N->getOperand(1);
@@ -13337,6 +13416,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1333713416
if (SimplifyDemandedVectorElts(SDValue(N, 0)))
1333813417
return SDValue(N, 0);
1333913418

13419+
if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
13420+
return V;
13421+
1334013422
return SDValue();
1334113423
}
1334213424

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -47256,13 +47256,14 @@ static SDValue combineToExtendBoolVectorInReg(
4725647256
DAG.getConstant(EltSizeInBits - 1, DL, VT));
4725747257
}
4725847258

47259-
/// If a vector select has an operand that is -1 or 0, try to simplify the
47259+
/// If a vector select has an left operand that is 0, try to simplify the
4726047260
/// select to a bitwise logic operation.
47261-
/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
47262-
static SDValue
47263-
combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
47264-
TargetLowering::DAGCombinerInfo &DCI,
47265-
const X86Subtarget &Subtarget) {
47261+
/// TODO: Move to DAGCombiner.combineVSelectWithAllOnesOrZeros, possibly using
47262+
/// TargetLowering::hasAndNot()?
47263+
static SDValue combineVSelectWithLastZeros(SDNode *N, SelectionDAG &DAG,
47264+
const SDLoc &DL,
47265+
TargetLowering::DAGCombinerInfo &DCI,
47266+
const X86Subtarget &Subtarget) {
4726647267
SDValue Cond = N->getOperand(0);
4726747268
SDValue LHS = N->getOperand(1);
4726847269
SDValue RHS = N->getOperand(2);
@@ -47275,20 +47276,6 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
4727547276

4727647277
assert(CondVT.isVector() && "Vector select expects a vector selector!");
4727747278

47278-
// TODO: Use isNullOrNullSplat() to distinguish constants with undefs?
47279-
// TODO: Can we assert that both operands are not zeros (because that should
47280-
// get simplified at node creation time)?
47281-
bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode());
47282-
bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
47283-
47284-
// If both inputs are 0/undef, create a complete zero vector.
47285-
// FIXME: As noted above this should be handled by DAGCombiner/getNode.
47286-
if (TValIsAllZeros && FValIsAllZeros) {
47287-
if (VT.isFloatingPoint())
47288-
return DAG.getConstantFP(0.0, DL, VT);
47289-
return DAG.getConstant(0, DL, VT);
47290-
}
47291-
4729247279
// To use the condition operand as a bitwise mask, it must have elements that
4729347280
// are the same size as the select elements. Ie, the condition operand must
4729447281
// have already been promoted from the IR select condition type <N x i1>.
@@ -47297,56 +47284,15 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
4729747284
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
4729847285
return SDValue();
4729947286

47300-
// Try to invert the condition if true value is not all 1s and false value is
47301-
// not all 0s. Only do this if the condition has one use.
47302-
bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode());
47303-
if (!TValIsAllOnes && !FValIsAllZeros && Cond.hasOneUse() &&
47304-
// Check if the selector will be produced by CMPP*/PCMP*.
47305-
Cond.getOpcode() == ISD::SETCC &&
47306-
// Check if SETCC has already been promoted.
47307-
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
47308-
CondVT) {
47309-
bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode());
47310-
47311-
if (TValIsAllZeros || FValIsAllOnes) {
47312-
SDValue CC = Cond.getOperand(2);
47313-
ISD::CondCode NewCC = ISD::getSetCCInverse(
47314-
cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
47315-
Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
47316-
NewCC);
47317-
std::swap(LHS, RHS);
47318-
TValIsAllOnes = FValIsAllOnes;
47319-
FValIsAllZeros = TValIsAllZeros;
47320-
}
47321-
}
47322-
4732347287
// Cond value must be 'sign splat' to be converted to a logical op.
4732447288
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
4732547289
return SDValue();
4732647290

47327-
// vselect Cond, 111..., 000... -> Cond
47328-
if (TValIsAllOnes && FValIsAllZeros)
47329-
return DAG.getBitcast(VT, Cond);
47330-
4733147291
if (!TLI.isTypeLegal(CondVT))
4733247292
return SDValue();
4733347293

47334-
// vselect Cond, 111..., X -> or Cond, X
47335-
if (TValIsAllOnes) {
47336-
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
47337-
SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, CastRHS);
47338-
return DAG.getBitcast(VT, Or);
47339-
}
47340-
47341-
// vselect Cond, X, 000... -> and Cond, X
47342-
if (FValIsAllZeros) {
47343-
SDValue CastLHS = DAG.getBitcast(CondVT, LHS);
47344-
SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, CastLHS);
47345-
return DAG.getBitcast(VT, And);
47346-
}
47347-
4734847294
// vselect Cond, 000..., X -> andn Cond, X
47349-
if (TValIsAllZeros) {
47295+
if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
4735047296
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
4735147297
SDValue AndN;
4735247298
// The canonical form differs for i1 vectors - x86andnp is not used
@@ -48107,7 +48053,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4810748053
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
4810848054
return SDValue();
4810948055

48110-
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DL, DCI, Subtarget))
48056+
if (SDValue V = combineVSelectWithLastZeros(N, DAG, DL, DCI, Subtarget))
4811148057
return V;
4811248058

4811348059
if (SDValue V = combineVSelectToBLENDV(N, DAG, DL, DCI, Subtarget))

llvm/test/CodeGen/X86/urem-seteq-vec-tautological.ll

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,9 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
7777
; CHECK-SSE2-LABEL: t1_all_odd_ne:
7878
; CHECK-SSE2: # %bb.0:
7979
; CHECK-SSE2-NEXT: pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
80-
; CHECK-SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
8180
; CHECK-SSE2-NEXT: pxor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
8281
; CHECK-SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
83-
; CHECK-SSE2-NEXT: pcmpeqd %xmm1, %xmm1
84-
; CHECK-SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
82+
; CHECK-SSE2-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
8583
; CHECK-SSE2-NEXT: retq
8684
;
8785
; CHECK-SSE41-LABEL: t1_all_odd_ne:
@@ -92,7 +90,7 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
9290
; CHECK-SSE41-NEXT: pcmpeqd %xmm1, %xmm0
9391
; CHECK-SSE41-NEXT: pcmpeqd %xmm1, %xmm1
9492
; CHECK-SSE41-NEXT: pxor %xmm1, %xmm0
95-
; CHECK-SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
93+
; CHECK-SSE41-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
9694
; CHECK-SSE41-NEXT: retq
9795
;
9896
; CHECK-AVX1-LABEL: t1_all_odd_ne:
@@ -102,7 +100,7 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
102100
; CHECK-AVX1-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
103101
; CHECK-AVX1-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
104102
; CHECK-AVX1-NEXT: vpxor %xmm1, %xmm0, %xmm0
105-
; CHECK-AVX1-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
103+
; CHECK-AVX1-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
106104
; CHECK-AVX1-NEXT: retq
107105
;
108106
; CHECK-AVX2-LABEL: t1_all_odd_ne:
@@ -113,17 +111,16 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
113111
; CHECK-AVX2-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
114112
; CHECK-AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
115113
; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
116-
; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
114+
; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
117115
; CHECK-AVX2-NEXT: retq
118116
;
119117
; CHECK-AVX512VL-LABEL: t1_all_odd_ne:
120118
; CHECK-AVX512VL: # %bb.0:
121119
; CHECK-AVX512VL-NEXT: vpmulld {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to4}, %xmm0, %xmm0
122120
; CHECK-AVX512VL-NEXT: vpminud {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
123-
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
124-
; CHECK-AVX512VL-NEXT: vpternlogq {{.*#+}} xmm0 = ~xmm0
125-
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
126-
; CHECK-AVX512VL-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
121+
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm1
122+
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
123+
; CHECK-AVX512VL-NEXT: vpternlogq {{.*#+}} xmm0 = m64bcst | (xmm0 ^ xmm1)
127124
; CHECK-AVX512VL-NEXT: retq
128125
%urem = urem <4 x i32> %X, <i32 3, i32 1, i32 1, i32 9>
129126
%cmp = icmp ne <4 x i32> %urem, <i32 0, i32 42, i32 0, i32 42>

0 commit comments

Comments
 (0)