Skip to content

Commit c67e83a

Browse files
committed
[X86] Improve helper for simplifying demanded bits of compares
We currently only handle a single case for `pcmpgt`. This patch extends that to work for `cmpp` and handles comparitors more generically.
1 parent 74b08f8 commit c67e83a

File tree

7 files changed

+253
-47
lines changed

7 files changed

+253
-47
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 188 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41346,6 +41346,156 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
4134641346
return SDValue();
4134741347
}
4134841348

41349+
// Simplify a decomposed (sext (setcc)). Assumes prior check that
41350+
// bitwidth(sext)==bitwidth(setcc operands).
41351+
static SDValue simplifySExtOfDecomposedSetCCImpl(
41352+
SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0,
41353+
SDValue Op1, const APInt &OriginalDemandedBits,
41354+
const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) {
41355+
// Possible TODO: We could handle any power of two demanded bit + unsigned
41356+
// comparison. There are no x86 specific comparisons that are unsigned so its
41357+
// unneeded.
41358+
if (!OriginalDemandedBits.isSignMask())
41359+
return SDValue();
41360+
41361+
EVT OpVT = Op0.getValueType();
41362+
// We need need nofpclass(nan inf nzero) to handle floats.
41363+
auto hasOkayFPFlags = [](SDValue Op) {
41364+
return Op.getOpcode() == ISD::SINT_TO_FP ||
41365+
Op.getOpcode() == ISD::UINT_TO_FP ||
41366+
(Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
41367+
Op->getFlags().hasNoSignedZeros());
41368+
};
41369+
41370+
if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
41371+
return SDValue();
41372+
41373+
auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
41374+
if (OpVT.isFloatingPoint()) {
41375+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41376+
return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
41377+
}
41378+
return V0.eq(V1);
41379+
};
41380+
41381+
// Assume we canonicalized constants to Op1. That isn't always true but we
41382+
// call this function twice with inverted CC/Operands so its fine either way.
41383+
APInt Op1C;
41384+
unsigned ValWidth = OriginalDemandedBits.getBitWidth();
41385+
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
41386+
Op1C = APInt::getZero(ValWidth);
41387+
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
41388+
Op1C = APInt::getAllOnes(ValWidth);
41389+
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
41390+
Op1C = C->getValueAPF().bitcastToAPInt();
41391+
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
41392+
Op1C = C->getAPIntValue();
41393+
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
41394+
// isConstantSplatVector sets `Op1C`.
41395+
} else {
41396+
return SDValue();
41397+
}
41398+
41399+
bool Not = false;
41400+
bool Okay = false;
41401+
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
41402+
"Invalid constant operand");
41403+
41404+
switch (CC) {
41405+
case ISD::SETGE:
41406+
case ISD::SETOGE:
41407+
Not = true;
41408+
[[fallthrough]];
41409+
case ISD::SETLT:
41410+
case ISD::SETOLT:
41411+
// signbit(sext(x s< 0)) == signbit(x)
41412+
// signbit(sext(x s>= 0)) == signbit(~x)
41413+
Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
41414+
// For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break
41415+
// this fold.
41416+
// NB: We only need de-norm check here, for the rest of the constants any
41417+
// relationship with a de-norm value and zero will be identical.
41418+
if (Okay && OpVT.isFloatingPoint()) {
41419+
// Values from integers are always normal.
41420+
if (Op0.getOpcode() == ISD::SINT_TO_FP ||
41421+
Op0.getOpcode() == ISD::UINT_TO_FP)
41422+
break;
41423+
41424+
// See if we can prove normal with known bits.
41425+
KnownBits Op0Known =
41426+
DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth);
41427+
// Negative/positive doesn't matter.
41428+
Op0Known.One.clearSignBit();
41429+
Op0Known.Zero.clearSignBit();
41430+
41431+
// Get min normal value.
41432+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41433+
KnownBits MinNormal = KnownBits::makeConstant(
41434+
APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
41435+
// Are we above de-norm range?
41436+
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
41437+
Okay = Op0Normal.value_or(false);
41438+
}
41439+
break;
41440+
case ISD::SETGT:
41441+
case ISD::SETOGT:
41442+
Not = true;
41443+
[[fallthrough]];
41444+
case ISD::SETLE:
41445+
case ISD::SETOLE:
41446+
// signbit(sext(x s<= -1)) == signbit(x)
41447+
// signbit(sext(x s> -1)) == signbit(~x)
41448+
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
41449+
break;
41450+
case ISD::SETULT:
41451+
Not = true;
41452+
[[fallthrough]];
41453+
case ISD::SETUGE:
41454+
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
41455+
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
41456+
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits);
41457+
break;
41458+
case ISD::SETULE:
41459+
Not = true;
41460+
[[fallthrough]];
41461+
case ISD::SETUGT:
41462+
// signbit(sext(x u> SIGNED_MAX)) == signbit(x)
41463+
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
41464+
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1);
41465+
break;
41466+
default:
41467+
break;
41468+
}
41469+
41470+
Okay &= Not ? AllowNOT : true;
41471+
if (!Okay)
41472+
return SDValue();
41473+
41474+
if (!Not)
41475+
return Op0;
41476+
41477+
if (!OpVT.isFloatingPoint())
41478+
return DAG.getNOT(DL, Op0, OpVT);
41479+
41480+
// Possible TODO: We could use `fneg` to do not.
41481+
return SDValue();
41482+
}
41483+
41484+
static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, const SDLoc &DL,
41485+
ISD::CondCode CC, SDValue Op0,
41486+
SDValue Op1,
41487+
const APInt &OriginalDemandedBits,
41488+
const APInt &OriginalDemandedElts,
41489+
bool AllowNOT, unsigned Depth) {
41490+
if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
41491+
DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts,
41492+
AllowNOT, Depth))
41493+
return R;
41494+
return simplifySExtOfDecomposedSetCCImpl(
41495+
DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
41496+
OriginalDemandedElts, AllowNOT, Depth);
41497+
}
41498+
4134941499
// Simplify variable target shuffle masks based on the demanded elements.
4135041500
// TODO: Handle DemandedBits in mask indices as well?
4135141501
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42525,13 +42675,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4252542675
}
4252642676
break;
4252742677
}
42528-
case X86ISD::PCMPGT:
42529-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42530-
// iff we only need the sign bit then we can use R directly.
42531-
if (OriginalDemandedBits.isSignMask() &&
42532-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42533-
return TLO.CombineTo(Op, Op.getOperand(1));
42678+
case X86ISD::PCMPGT: {
42679+
SDLoc DL(Op);
42680+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42681+
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42682+
OriginalDemandedBits, OriginalDemandedElts,
42683+
/*AllowNOT*/ true, Depth))
42684+
return TLO.CombineTo(Op, R);
42685+
break;
42686+
}
42687+
case X86ISD::CMPP: {
42688+
SDLoc DL(Op);
42689+
ISD::CondCode CC = X86::getCondForCMPPImm(
42690+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42691+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42692+
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
42693+
OriginalDemandedBits, OriginalDemandedElts,
42694+
!(TLO.LegalOperations() && TLO.LegalTypes()), Depth))
42695+
return TLO.CombineTo(Op, R);
4253442696
break;
42697+
}
4253542698
case X86ISD::MOVMSK: {
4253642699
SDValue Src = Op.getOperand(0);
4253742700
MVT SrcVT = Src.getSimpleValueType();
@@ -42715,13 +42878,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
4271542878
if (DemandedBits.isSignMask())
4271642879
return Op.getOperand(0);
4271742880
break;
42718-
case X86ISD::PCMPGT:
42719-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42720-
// iff we only need the sign bit then we can use R directly.
42721-
if (DemandedBits.isSignMask() &&
42722-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42723-
return Op.getOperand(1);
42881+
case X86ISD::PCMPGT: {
42882+
SDLoc DL(Op);
42883+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42884+
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42885+
DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth))
42886+
return R;
42887+
break;
42888+
}
42889+
case X86ISD::CMPP: {
42890+
SDLoc DL(Op);
42891+
ISD::CondCode CC = X86::getCondForCMPPImm(
42892+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42893+
if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0),
42894+
Op.getOperand(1),
42895+
DemandedBits, DemandedElts,
42896+
/*AllowNOT*/ false, Depth))
42897+
return R;
4272442898
break;
42899+
}
4272542900
case X86ISD::BLENDV: {
4272642901
// BLENDV: Cond (MSB) ? LHS : RHS
4272742902
SDValue Cond = Op.getOperand(0);
@@ -48397,7 +48572,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,
4839748572

4839848573
// We do not split for SSE at all, but we need to split vectors for AVX1 and
4839948574
// AVX2.
48400-
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
48575+
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
4840148576
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) {
4840248577
SDValue LoX, HiX;
4840348578
std::tie(LoX, HiX) = splitVector(X, DAG, DL);

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,6 +3360,46 @@ unsigned X86::getVPCMPImmForCond(ISD::CondCode CC) {
33603360
}
33613361
}
33623362

3363+
ISD::CondCode X86::getCondForCMPPImm(unsigned Imm) {
3364+
assert(Imm <= 0x1f && "Invalid CMPP Imm");
3365+
switch (Imm & 0xf) {
3366+
default:
3367+
llvm_unreachable("Invalid CMPP Imm");
3368+
case 0:
3369+
return ISD::SETOEQ;
3370+
case 1:
3371+
return ISD::SETOLT;
3372+
case 2:
3373+
return ISD::SETOLE;
3374+
case 3:
3375+
return ISD::SETUO;
3376+
case 4:
3377+
return ISD::SETUNE;
3378+
case 5:
3379+
return ISD::SETUGE;
3380+
case 6:
3381+
return ISD::SETUGT;
3382+
case 7:
3383+
return ISD::SETO;
3384+
case 8:
3385+
return ISD::SETUEQ;
3386+
case 9:
3387+
return ISD::SETULT;
3388+
case 10:
3389+
return ISD::SETULE;
3390+
case 11:
3391+
return ISD::SETFALSE;
3392+
case 12:
3393+
return ISD::SETONE;
3394+
case 13:
3395+
return ISD::SETOGE;
3396+
case 14:
3397+
return ISD::SETOGT;
3398+
case 15:
3399+
return ISD::SETTRUE;
3400+
}
3401+
}
3402+
33633403
/// Get the VPCMP immediate if the operands are swapped.
33643404
unsigned X86::getSwappedVPCMPImm(unsigned Imm) {
33653405
switch (Imm) {

llvm/lib/Target/X86/X86InstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ CondCode GetOppositeBranchCondition(CondCode CC);
7272
/// Get the VPCMP immediate for the given condition.
7373
unsigned getVPCMPImmForCond(ISD::CondCode CC);
7474

75+
/// Get the CondCode from a CMPP immediate.
76+
ISD::CondCode getCondForCMPPImm(unsigned Imm);
77+
7578
/// Get the VPCMP immediate if the opcodes are swapped.
7679
unsigned getSwappedVPCMPImm(unsigned Imm);
7780

llvm/test/CodeGen/X86/combine-testps.ll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ define i32 @testpsnzc_256_signbit(<8 x float> %c, <8 x float> %d, i32 %a, i32 %b
175175
; AVX: # %bb.0:
176176
; AVX-NEXT: movl %edi, %eax
177177
; AVX-NEXT: vcvtdq2ps %ymm0, %ymm0
178-
; AVX-NEXT: vxorps %xmm2, %xmm2, %xmm2
179-
; AVX-NEXT: vcmpltps %ymm2, %ymm0, %ymm0
180178
; AVX-NEXT: vtestps %ymm1, %ymm0
181179
; AVX-NEXT: cmovnel %esi, %eax
182180
; AVX-NEXT: vzeroupper

llvm/test/CodeGen/X86/sadd_sat_vec.ll

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -862,8 +862,6 @@ define <8 x i32> @v8i32(<8 x i32> %x, <8 x i32> %y) nounwind {
862862
; AVX1-NEXT: vpcmpgtd %xmm4, %xmm0, %xmm0
863863
; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm0, %ymm0
864864
; AVX1-NEXT: vcvtdq2ps %ymm1, %ymm1
865-
; AVX1-NEXT: vpxor %xmm3, %xmm3, %xmm3
866-
; AVX1-NEXT: vcmpltps %ymm3, %ymm1, %ymm1
867865
; AVX1-NEXT: vxorps %ymm0, %ymm1, %ymm0
868866
; AVX1-NEXT: vpsrad $31, %xmm4, %xmm1
869867
; AVX1-NEXT: vpsrad $31, %xmm2, %xmm2
@@ -1063,8 +1061,6 @@ define <16 x i32> @v16i32(<16 x i32> %x, <16 x i32> %y) nounwind {
10631061
; AVX1-NEXT: vpcmpgtd %xmm6, %xmm0, %xmm0
10641062
; AVX1-NEXT: vinsertf128 $1, %xmm5, %ymm0, %ymm0
10651063
; AVX1-NEXT: vcvtdq2ps %ymm2, %ymm2
1066-
; AVX1-NEXT: vpxor %xmm5, %xmm5, %xmm5
1067-
; AVX1-NEXT: vcmpltps %ymm5, %ymm2, %ymm2
10681064
; AVX1-NEXT: vxorps %ymm0, %ymm2, %ymm0
10691065
; AVX1-NEXT: vpsrad $31, %xmm6, %xmm2
10701066
; AVX1-NEXT: vpsrad $31, %xmm4, %xmm4
@@ -1073,21 +1069,20 @@ define <16 x i32> @v16i32(<16 x i32> %x, <16 x i32> %y) nounwind {
10731069
; AVX1-NEXT: vxorps %ymm4, %ymm2, %ymm2
10741070
; AVX1-NEXT: vblendvps %ymm0, %ymm2, %ymm7, %ymm0
10751071
; AVX1-NEXT: vextractf128 $1, %ymm3, %xmm2
1076-
; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm6
1077-
; AVX1-NEXT: vpaddd %xmm2, %xmm6, %xmm2
1078-
; AVX1-NEXT: vpaddd %xmm3, %xmm1, %xmm7
1079-
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm7, %ymm8
1080-
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm6, %xmm6
1081-
; AVX1-NEXT: vpcmpgtd %xmm7, %xmm1, %xmm1
1082-
; AVX1-NEXT: vinsertf128 $1, %xmm6, %ymm1, %ymm1
1072+
; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm5
1073+
; AVX1-NEXT: vpaddd %xmm2, %xmm5, %xmm2
1074+
; AVX1-NEXT: vpaddd %xmm3, %xmm1, %xmm6
1075+
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm6, %ymm7
1076+
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm5, %xmm5
1077+
; AVX1-NEXT: vpcmpgtd %xmm6, %xmm1, %xmm1
1078+
; AVX1-NEXT: vinsertf128 $1, %xmm5, %ymm1, %ymm1
10831079
; AVX1-NEXT: vcvtdq2ps %ymm3, %ymm3
1084-
; AVX1-NEXT: vcmpltps %ymm5, %ymm3, %ymm3
10851080
; AVX1-NEXT: vxorps %ymm1, %ymm3, %ymm1
1086-
; AVX1-NEXT: vpsrad $31, %xmm7, %xmm3
1081+
; AVX1-NEXT: vpsrad $31, %xmm6, %xmm3
10871082
; AVX1-NEXT: vpsrad $31, %xmm2, %xmm2
10881083
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm3, %ymm2
10891084
; AVX1-NEXT: vxorps %ymm4, %ymm2, %ymm2
1090-
; AVX1-NEXT: vblendvps %ymm1, %ymm2, %ymm8, %ymm1
1085+
; AVX1-NEXT: vblendvps %ymm1, %ymm2, %ymm7, %ymm1
10911086
; AVX1-NEXT: retq
10921087
;
10931088
; AVX2-LABEL: v16i32:

llvm/test/CodeGen/X86/vector-pcmp.ll

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,35 +1957,32 @@ define <4 x i64> @PR52504(<4 x i16> %t3) {
19571957
; SSE42-LABEL: PR52504:
19581958
; SSE42: # %bb.0:
19591959
; SSE42-NEXT: pshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
1960-
; SSE42-NEXT: pmovsxwq %xmm1, %xmm2
1961-
; SSE42-NEXT: pmovsxwq %xmm0, %xmm3
1962-
; SSE42-NEXT: pxor %xmm1, %xmm1
1963-
; SSE42-NEXT: pxor %xmm0, %xmm0
1964-
; SSE42-NEXT: pcmpgtq %xmm3, %xmm0
1965-
; SSE42-NEXT: por %xmm3, %xmm0
1966-
; SSE42-NEXT: pcmpgtq %xmm2, %xmm1
1967-
; SSE42-NEXT: por %xmm2, %xmm1
1960+
; SSE42-NEXT: pmovsxwq %xmm1, %xmm1
1961+
; SSE42-NEXT: pmovsxwq %xmm0, %xmm2
1962+
; SSE42-NEXT: pcmpeqd %xmm3, %xmm3
1963+
; SSE42-NEXT: movdqa %xmm2, %xmm0
1964+
; SSE42-NEXT: blendvpd %xmm0, %xmm3, %xmm2
1965+
; SSE42-NEXT: movdqa %xmm1, %xmm0
1966+
; SSE42-NEXT: blendvpd %xmm0, %xmm3, %xmm1
1967+
; SSE42-NEXT: movapd %xmm2, %xmm0
19681968
; SSE42-NEXT: retq
19691969
;
19701970
; AVX1-LABEL: PR52504:
19711971
; AVX1: # %bb.0:
19721972
; AVX1-NEXT: vpmovsxwq %xmm0, %xmm1
19731973
; AVX1-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,1,1,1]
19741974
; AVX1-NEXT: vpmovsxwq %xmm0, %xmm0
1975-
; AVX1-NEXT: vpxor %xmm2, %xmm2, %xmm2
1976-
; AVX1-NEXT: vpcmpgtq %xmm0, %xmm2, %xmm3
1977-
; AVX1-NEXT: vpor %xmm0, %xmm3, %xmm0
1978-
; AVX1-NEXT: vpcmpgtq %xmm1, %xmm2, %xmm2
1979-
; AVX1-NEXT: vpor %xmm1, %xmm2, %xmm1
1975+
; AVX1-NEXT: vpcmpeqd %xmm2, %xmm2, %xmm2
1976+
; AVX1-NEXT: vblendvpd %xmm0, %xmm2, %xmm0, %xmm0
1977+
; AVX1-NEXT: vblendvpd %xmm1, %xmm2, %xmm1, %xmm1
19801978
; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm0
19811979
; AVX1-NEXT: retq
19821980
;
19831981
; AVX2-LABEL: PR52504:
19841982
; AVX2: # %bb.0:
19851983
; AVX2-NEXT: vpmovsxwq %xmm0, %ymm0
1986-
; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
1987-
; AVX2-NEXT: vpcmpgtq %ymm0, %ymm1, %ymm1
1988-
; AVX2-NEXT: vpor %ymm0, %ymm1, %ymm0
1984+
; AVX2-NEXT: vpcmpeqd %ymm1, %ymm1, %ymm1
1985+
; AVX2-NEXT: vblendvpd %ymm0, %ymm1, %ymm0, %ymm0
19891986
; AVX2-NEXT: retq
19901987
;
19911988
; AVX512-LABEL: PR52504:

0 commit comments

Comments
 (0)