Skip to content

Commit 4103b4c

Browse files
committed
feat: move combineVSelectWithAllOnesOrZeros to DAGCombiner and x86 test
1 parent 54953b9 commit 4103b4c

File tree

3 files changed

+98
-73
lines changed

3 files changed

+98
-73
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12987,6 +12987,85 @@ SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
1298712987
return SDValue();
1298812988
}
1298912989

12990+
/// Try to fold a vector select whose true and/or false arm is an all-zeros or
/// all-ones build-vector into a bitwise logic operation on the condition mask.
///
/// Folds performed (possibly after inverting a single-use SETCC condition and
/// swapping the two arms):
///   vselect Cond,  0,  0 -> 0
///   vselect Cond, -1,  0 -> bitcast Cond
///   vselect Cond, -1,  X -> or  Cond, X
///   vselect Cond,  X,  0 -> and Cond, X
///
/// The remaining shape (vselect Cond, 0, X with a condition that cannot be
/// inverted here) is deliberately not handled and yields an empty SDValue.
///
/// \param Cond  The vector select condition.
/// \param TVal  Value selected where the condition is true.
/// \param FVal  Value selected where the condition is false.
/// \param TLI   Target lowering info, used for type-legality and SETCC result
///              type queries.
/// \param DAG   The DAG in which replacement nodes are created.
/// \param DL    Debug location attached to any newly created node.
/// \returns the replacement value, or an empty SDValue if no fold applies.
static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
                                                SDValue FVal,
                                                const TargetLowering &TLI,
                                                SelectionDAG &DAG,
                                                const SDLoc &DL) {
  EVT VT = TVal.getValueType();
  if (!TLI.isTypeLegal(VT))
    return SDValue();

  EVT CondVT = Cond.getValueType();
  assert(CondVT.isVector() && "Vector select expects a vector selector!");

  // Classify both arms once; every decision below is driven by these flags.
  bool IsTAllZero = ISD::isBuildVectorAllZeros(TVal.getNode());
  bool IsTAllOne = ISD::isBuildVectorAllOnes(TVal.getNode());
  bool IsFAllZero = ISD::isBuildVectorAllZeros(FVal.getNode());
  bool IsFAllOne = ISD::isBuildVectorAllOnes(FVal.getNode());

  // Neither arm is a 0/-1 build-vector: nothing to fold.
  if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
    return SDValue();

  // vselect Cond, 0, 0 -> 0
  if (IsTAllZero && IsFAllZero)
    return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
                                : DAG.getConstant(0, DL, VT);

  // To use the condition operand as a bitwise mask, it must have elements that
  // are the same size as the select elements, i.e. the condition operand must
  // have already been promoted from the IR select condition type <N x i1>.
  // Don't check if the types themselves are equal because that excludes
  // vector floating-point selects.
  if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
    return SDValue();

  // Cond value must be 'sign splat' (every bit of an element equals its sign
  // bit) to be converted to a logical op. Check this *before* building an
  // inverted SETCC below so that no new node is created on a path that then
  // bails out. The inverted SETCC has the same result type as the original,
  // so its sign-bit count is expected to match.
  if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
    return SDValue();

  // If the select is not already in "-1 on true" or "0 on false" form, try
  // inverting the condition and swapping the arms to get there. At this point
  // at least one flag is set, so !IsTAllOne && !IsFAllZero already implies
  // IsTAllZero || IsFAllOne. Only do this for a single-use SETCC whose result
  // type is the target's SETCC result type for VT.
  if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
      Cond.getOpcode() == ISD::SETCC &&
      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
          CondVT) {
    ISD::CondCode InverseCC = ISD::getSetCCInverse(
        cast<CondCodeSDNode>(Cond.getOperand(2))->get(),
        Cond.getOperand(0).getValueType());
    Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
                        InverseCC);
    std::swap(TVal, FVal);
    std::swap(IsTAllOne, IsFAllOne);
    std::swap(IsTAllZero, IsFAllZero);
  }

  // vselect Cond, -1, 0 -> bitcast Cond
  if (IsTAllOne && IsFAllZero)
    return DAG.getBitcast(VT, Cond);

  // vselect Cond, -1, X -> or Cond, X
  if (IsTAllOne) {
    SDValue X = DAG.getBitcast(CondVT, FVal);
    SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, Or);
  }

  // vselect Cond, X, 0 -> and Cond, X
  if (IsFAllZero) {
    SDValue X = DAG.getBitcast(CondVT, TVal);
    SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, And);
  }

  // vselect Cond, 0, X with a non-invertible condition: not handled here.
  return SDValue();
}
13068+
1299013069
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1299113070
SDValue N0 = N->getOperand(0);
1299213071
SDValue N1 = N->getOperand(1);
@@ -13255,6 +13334,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1325513334
if (SimplifyDemandedVectorElts(SDValue(N, 0)))
1325613335
return SDValue(N, 0);
1325713336

13337+
if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
13338+
return V;
13339+
1325813340
return SDValue();
1325913341
}
1326013342

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -47264,13 +47264,14 @@ static SDValue combineToExtendBoolVectorInReg(
4726447264
DAG.getConstant(EltSizeInBits - 1, DL, VT));
4726547265
}
4726647266

47267-
/// If a vector select has an operand that is -1 or 0, try to simplify the
47267+
/// If a vector select has a left operand that is 0, try to simplify the
4726847268
/// select to a bitwise logic operation.
47269-
/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
47270-
static SDValue
47271-
combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
47272-
TargetLowering::DAGCombinerInfo &DCI,
47273-
const X86Subtarget &Subtarget) {
47269+
/// TODO: Move to DAGCombiner.combineVSelectWithAllOnesOrZeros, possibly using
47270+
/// TargetLowering::hasAndNot()?
47271+
static SDValue combineVSelectWithLastZeros(SDNode *N, SelectionDAG &DAG,
47272+
const SDLoc &DL,
47273+
TargetLowering::DAGCombinerInfo &DCI,
47274+
const X86Subtarget &Subtarget) {
4727447275
SDValue Cond = N->getOperand(0);
4727547276
SDValue LHS = N->getOperand(1);
4727647277
SDValue RHS = N->getOperand(2);
@@ -47283,20 +47284,6 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
4728347284

4728447285
assert(CondVT.isVector() && "Vector select expects a vector selector!");
4728547286

47286-
// TODO: Use isNullOrNullSplat() to distinguish constants with undefs?
47287-
// TODO: Can we assert that both operands are not zeros (because that should
47288-
// get simplified at node creation time)?
47289-
bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode());
47290-
bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
47291-
47292-
// If both inputs are 0/undef, create a complete zero vector.
47293-
// FIXME: As noted above this should be handled by DAGCombiner/getNode.
47294-
if (TValIsAllZeros && FValIsAllZeros) {
47295-
if (VT.isFloatingPoint())
47296-
return DAG.getConstantFP(0.0, DL, VT);
47297-
return DAG.getConstant(0, DL, VT);
47298-
}
47299-
4730047287
// To use the condition operand as a bitwise mask, it must have elements that
4730147288
// are the same size as the select elements. I.e., the condition operand must
4730247289
// have already been promoted from the IR select condition type <N x i1>.
@@ -47305,56 +47292,15 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
4730547292
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
4730647293
return SDValue();
4730747294

47308-
// Try to invert the condition if true value is not all 1s and false value is
47309-
// not all 0s. Only do this if the condition has one use.
47310-
bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode());
47311-
if (!TValIsAllOnes && !FValIsAllZeros && Cond.hasOneUse() &&
47312-
// Check if the selector will be produced by CMPP*/PCMP*.
47313-
Cond.getOpcode() == ISD::SETCC &&
47314-
// Check if SETCC has already been promoted.
47315-
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
47316-
CondVT) {
47317-
bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode());
47318-
47319-
if (TValIsAllZeros || FValIsAllOnes) {
47320-
SDValue CC = Cond.getOperand(2);
47321-
ISD::CondCode NewCC = ISD::getSetCCInverse(
47322-
cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
47323-
Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
47324-
NewCC);
47325-
std::swap(LHS, RHS);
47326-
TValIsAllOnes = FValIsAllOnes;
47327-
FValIsAllZeros = TValIsAllZeros;
47328-
}
47329-
}
47330-
4733147295
// Cond value must be 'sign splat' to be converted to a logical op.
4733247296
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
4733347297
return SDValue();
4733447298

47335-
// vselect Cond, 111..., 000... -> Cond
47336-
if (TValIsAllOnes && FValIsAllZeros)
47337-
return DAG.getBitcast(VT, Cond);
47338-
4733947299
if (!TLI.isTypeLegal(CondVT))
4734047300
return SDValue();
4734147301

47342-
// vselect Cond, 111..., X -> or Cond, X
47343-
if (TValIsAllOnes) {
47344-
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
47345-
SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, CastRHS);
47346-
return DAG.getBitcast(VT, Or);
47347-
}
47348-
47349-
// vselect Cond, X, 000... -> and Cond, X
47350-
if (FValIsAllZeros) {
47351-
SDValue CastLHS = DAG.getBitcast(CondVT, LHS);
47352-
SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, CastLHS);
47353-
return DAG.getBitcast(VT, And);
47354-
}
47355-
4735647302
// vselect Cond, 000..., X -> andn Cond, X
47357-
if (TValIsAllZeros) {
47303+
if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
4735847304
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
4735947305
SDValue AndN;
4736047306
// The canonical form differs for i1 vectors - x86andnp is not used
@@ -48115,7 +48061,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4811548061
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
4811648062
return SDValue();
4811748063

48118-
if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DL, DCI, Subtarget))
48064+
if (SDValue V = combineVSelectWithLastZeros(N, DAG, DL, DCI, Subtarget))
4811948065
return V;
4812048066

4812148067
if (SDValue V = combineVSelectToBLENDV(N, DAG, DL, DCI, Subtarget))

llvm/test/CodeGen/X86/urem-seteq-vec-tautological.ll

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,9 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
7777
; CHECK-SSE2-LABEL: t1_all_odd_ne:
7878
; CHECK-SSE2: # %bb.0:
7979
; CHECK-SSE2-NEXT: pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
80-
; CHECK-SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
8180
; CHECK-SSE2-NEXT: pxor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
8281
; CHECK-SSE2-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
83-
; CHECK-SSE2-NEXT: pcmpeqd %xmm1, %xmm1
84-
; CHECK-SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
82+
; CHECK-SSE2-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
8583
; CHECK-SSE2-NEXT: retq
8684
;
8785
; CHECK-SSE41-LABEL: t1_all_odd_ne:
@@ -92,7 +90,7 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
9290
; CHECK-SSE41-NEXT: pcmpeqd %xmm1, %xmm0
9391
; CHECK-SSE41-NEXT: pcmpeqd %xmm1, %xmm1
9492
; CHECK-SSE41-NEXT: pxor %xmm1, %xmm0
95-
; CHECK-SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
93+
; CHECK-SSE41-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
9694
; CHECK-SSE41-NEXT: retq
9795
;
9896
; CHECK-AVX1-LABEL: t1_all_odd_ne:
@@ -102,7 +100,7 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
102100
; CHECK-AVX1-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
103101
; CHECK-AVX1-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
104102
; CHECK-AVX1-NEXT: vpxor %xmm1, %xmm0, %xmm0
105-
; CHECK-AVX1-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
103+
; CHECK-AVX1-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
106104
; CHECK-AVX1-NEXT: retq
107105
;
108106
; CHECK-AVX2-LABEL: t1_all_odd_ne:
@@ -113,17 +111,16 @@ define <4 x i1> @t1_all_odd_ne(<4 x i32> %X) nounwind {
113111
; CHECK-AVX2-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
114112
; CHECK-AVX2-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
115113
; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm0, %xmm0
116-
; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
114+
; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
117115
; CHECK-AVX2-NEXT: retq
118116
;
119117
; CHECK-AVX512VL-LABEL: t1_all_odd_ne:
120118
; CHECK-AVX512VL: # %bb.0:
121119
; CHECK-AVX512VL-NEXT: vpmulld {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to4}, %xmm0, %xmm0
122120
; CHECK-AVX512VL-NEXT: vpminud {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
123-
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm0
124-
; CHECK-AVX512VL-NEXT: vpternlogq {{.*#+}} xmm0 = ~xmm0
125-
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1
126-
; CHECK-AVX512VL-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
121+
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm1, %xmm0, %xmm1
122+
; CHECK-AVX512VL-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
123+
; CHECK-AVX512VL-NEXT: vpternlogq {{.*#+}} xmm0 = m64bcst | (xmm0 ^ xmm1)
127124
; CHECK-AVX512VL-NEXT: retq
128125
%urem = urem <4 x i32> %X, <i32 3, i32 1, i32 1, i32 9>
129126
%cmp = icmp ne <4 x i32> %urem, <i32 0, i32 42, i32 0, i32 42>

0 commit comments

Comments
 (0)