Skip to content

Commit f6b0d17

Browse files
committed
[X86] Improve helper for simplifying demanded bits of compares
We currently only handle a single case for `pcmpgt`. This patch extends that to work for `cmpp` and handles comparitors more generically.
1 parent bc5d523 commit f6b0d17

File tree

6 files changed

+258
-64
lines changed

6 files changed

+258
-64
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 186 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41340,6 +41340,154 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
4134041340
return SDValue();
4134141341
}
4134241342

41343+
// Simplify a decomposed (sext (setcc)). Assumes prior check that
41344+
// bitwidth(sext)==bitwidth(setcc operands).
41345+
static SDValue simplifySExtOfDecomposedSetCCImpl(
41346+
SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0,
41347+
SDValue Op1, const APInt &OriginalDemandedBits,
41348+
const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) {
41349+
// Possible TODO: We could handle any power of two demanded bit + unsigned
41350+
// comparison. There are no x86 specific comparisons that are unsigned so its
41351+
// unneeded.
41352+
if (!OriginalDemandedBits.isSignMask())
41353+
return SDValue();
41354+
41355+
EVT OpVT = Op0.getValueType();
41356+
// We need need nofpclass(nan inf nzero) to handle floats.
41357+
auto hasOkayFPFlags = [](SDValue Op) {
41358+
return Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
41359+
Op->getFlags().hasNoSignedZeros();
41360+
};
41361+
41362+
if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
41363+
return SDValue();
41364+
41365+
auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
41366+
if (OpVT.isFloatingPoint()) {
41367+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41368+
return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
41369+
}
41370+
return V0.eq(V1);
41371+
};
41372+
41373+
// Assume we canonicalized constants to Op1. That isn't always true but we
41374+
// call this function twice with inverted CC/Operands so its fine either way.
41375+
APInt Op1C;
41376+
unsigned ValWidth = OriginalDemandedBits.getBitWidth();
41377+
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
41378+
Op1C = APInt::getZero(ValWidth);
41379+
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
41380+
Op1C = APInt::getAllOnes(ValWidth);
41381+
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
41382+
Op1C = C->getValueAPF().bitcastToAPInt();
41383+
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
41384+
Op1C = C->getAPIntValue();
41385+
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
41386+
// isConstantSplatVector sets `Op1C`.
41387+
} else {
41388+
return SDValue();
41389+
}
41390+
41391+
bool Not = false;
41392+
bool Okay = false;
41393+
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
41394+
"Invalid constant operand");
41395+
41396+
switch (CC) {
41397+
case ISD::SETGE:
41398+
case ISD::SETOGE:
41399+
Not = true;
41400+
[[fallthrough]];
41401+
case ISD::SETLT:
41402+
case ISD::SETOLT:
41403+
// signbit(sext(x s< 0)) == signbit(x)
41404+
// signbit(sext(x s>= 0)) == signbit(~x)
41405+
Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
41406+
// For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break
41407+
// this fold.
41408+
// NB: We only need de-norm check here, for the rest of the constants any
41409+
// relationship with a de-norm value and zero will be identical.
41410+
if (Okay && OpVT.isFloatingPoint()) {
41411+
// Values from integers are always normal.
41412+
if (Op0.getOpcode() == ISD::SINT_TO_FP ||
41413+
Op0.getOpcode() == ISD::UINT_TO_FP)
41414+
break;
41415+
41416+
// See if we can prove normal with known bits.
41417+
KnownBits Op0Known =
41418+
DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth);
41419+
// Negative/positive doesn't matter.
41420+
Op0Known.One.clearSignBit();
41421+
Op0Known.Zero.clearSignBit();
41422+
41423+
// Get min normal value.
41424+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41425+
KnownBits MinNormal = KnownBits::makeConstant(
41426+
APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
41427+
// Are we above de-norm range?
41428+
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
41429+
Okay = Op0Normal.value_or(false);
41430+
}
41431+
break;
41432+
case ISD::SETGT:
41433+
case ISD::SETOGT:
41434+
Not = true;
41435+
[[fallthrough]];
41436+
case ISD::SETLE:
41437+
case ISD::SETOLE:
41438+
// signbit(sext(x s<= -1)) == signbit(x)
41439+
// signbit(sext(x s> -1)) == signbit(~x)
41440+
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
41441+
break;
41442+
case ISD::SETULT:
41443+
Not = true;
41444+
[[fallthrough]];
41445+
case ISD::SETUGE:
41446+
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
41447+
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
41448+
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits);
41449+
break;
41450+
case ISD::SETULE:
41451+
Not = true;
41452+
[[fallthrough]];
41453+
case ISD::SETUGT:
41454+
// signbit(sext(x u> SIGNED_MAX)) == signbit(x)
41455+
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
41456+
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1);
41457+
break;
41458+
default:
41459+
break;
41460+
}
41461+
41462+
Okay &= Not ? AllowNOT : true;
41463+
if (!Okay)
41464+
return SDValue();
41465+
41466+
if (!Not)
41467+
return Op0;
41468+
41469+
if (!OpVT.isFloatingPoint())
41470+
return DAG.getNOT(DL, Op0, OpVT);
41471+
41472+
// Possible TODO: We could use `fneg` to do not.
41473+
return SDValue();
41474+
}
41475+
41476+
static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
41477+
ISD::CondCode CC, SDValue Op0,
41478+
SDValue Op1,
41479+
const APInt &OriginalDemandedBits,
41480+
const APInt &OriginalDemandedElts,
41481+
bool AllowNOT, unsigned Depth) {
41482+
if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
41483+
DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts,
41484+
AllowNOT, Depth))
41485+
return R;
41486+
return simplifySExtOfDecomposedSetCCImpl(
41487+
DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
41488+
OriginalDemandedElts, AllowNOT, Depth);
41489+
}
41490+
4134341491
// Simplify variable target shuffle masks based on the demanded elements.
4134441492
// TODO: Handle DemandedBits in mask indices as well?
4134541493
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42519,13 +42667,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4251942667
}
4252042668
break;
4252142669
}
42522-
case X86ISD::PCMPGT:
42523-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42524-
// iff we only need the sign bit then we can use R directly.
42525-
if (OriginalDemandedBits.isSignMask() &&
42526-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42527-
return TLO.CombineTo(Op, Op.getOperand(1));
42670+
case X86ISD::PCMPGT: {
42671+
SDLoc DL(Op);
42672+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42673+
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42674+
OriginalDemandedBits, OriginalDemandedElts,
42675+
/*AllowNOT*/ true, Depth))
42676+
return TLO.CombineTo(Op, R);
42677+
break;
42678+
}
42679+
case X86ISD::CMPP: {
42680+
SDLoc DL(Op);
42681+
ISD::CondCode CC = X86::getCondForCMPPImm(
42682+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42683+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42684+
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
42685+
OriginalDemandedBits, OriginalDemandedElts,
42686+
!(TLO.LegalOperations() && TLO.LegalTypes()), Depth))
42687+
return TLO.CombineTo(Op, R);
4252842688
break;
42689+
}
4252942690
case X86ISD::MOVMSK: {
4253042691
SDValue Src = Op.getOperand(0);
4253142692
MVT SrcVT = Src.getSimpleValueType();
@@ -42709,13 +42870,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
4270942870
if (DemandedBits.isSignMask())
4271042871
return Op.getOperand(0);
4271142872
break;
42712-
case X86ISD::PCMPGT:
42713-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42714-
// iff we only need the sign bit then we can use R directly.
42715-
if (DemandedBits.isSignMask() &&
42716-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42717-
return Op.getOperand(1);
42873+
case X86ISD::PCMPGT: {
42874+
SDLoc DL(Op);
42875+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42876+
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42877+
DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth))
42878+
return R;
42879+
break;
42880+
}
42881+
case X86ISD::CMPP: {
42882+
SDLoc DL(Op);
42883+
ISD::CondCode CC = X86::getCondForCMPPImm(
42884+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42885+
if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0),
42886+
Op.getOperand(1),
42887+
DemandedBits, DemandedElts,
42888+
/*AllowNOT*/ false, Depth))
42889+
return R;
4271842890
break;
42891+
}
4271942892
case X86ISD::BLENDV: {
4272042893
// BLENDV: Cond (MSB) ? LHS : RHS
4272142894
SDValue Cond = Op.getOperand(0);
@@ -48391,7 +48564,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,
4839148564

4839248565
// We do not split for SSE at all, but we need to split vectors for AVX1 and
4839348566
// AVX2.
48394-
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
48567+
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
4839548568
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) {
4839648569
SDValue LoX, HiX;
4839748570
std::tie(LoX, HiX) = splitVector(X, DAG, DL);

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3349,6 +3349,46 @@ unsigned X86::getVPCMPImmForCond(ISD::CondCode CC) {
33493349
}
33503350
}
33513351

3352+
ISD::CondCode X86::getCondForCMPPImm(unsigned Imm) {
3353+
assert(Imm <= 0x1f && "Invalid CMPP Imm");
3354+
switch (Imm & 0xf) {
3355+
default:
3356+
llvm_unreachable("Invalid CMPP Imm");
3357+
case 0:
3358+
return ISD::SETOEQ;
3359+
case 1:
3360+
return ISD::SETOLT;
3361+
case 2:
3362+
return ISD::SETOLE;
3363+
case 3:
3364+
return ISD::SETUO;
3365+
case 4:
3366+
return ISD::SETUNE;
3367+
case 5:
3368+
return ISD::SETUGE;
3369+
case 6:
3370+
return ISD::SETUGT;
3371+
case 7:
3372+
return ISD::SETO;
3373+
case 8:
3374+
return ISD::SETUEQ;
3375+
case 9:
3376+
return ISD::SETULT;
3377+
case 10:
3378+
return ISD::SETULE;
3379+
case 11:
3380+
return ISD::SETFALSE;
3381+
case 12:
3382+
return ISD::SETONE;
3383+
case 13:
3384+
return ISD::SETOGE;
3385+
case 14:
3386+
return ISD::SETOGT;
3387+
case 15:
3388+
return ISD::SETTRUE;
3389+
}
3390+
}
3391+
33523392
/// Get the VPCMP immediate if the operands are swapped.
33533393
unsigned X86::getSwappedVPCMPImm(unsigned Imm) {
33543394
switch (Imm) {

llvm/lib/Target/X86/X86InstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ CondCode GetOppositeBranchCondition(CondCode CC);
6868
/// Get the VPCMP immediate for the given condition.
6969
unsigned getVPCMPImmForCond(ISD::CondCode CC);
7070

71+
/// Get the CondCode from a CMPP immediate.
72+
ISD::CondCode getCondForCMPPImm(unsigned Imm);
73+
7174
/// Get the VPCMP immediate if the opcodes are swapped.
7275
unsigned getSwappedVPCMPImm(unsigned Imm);
7376

llvm/test/CodeGen/X86/combine-testps.ll

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -171,24 +171,13 @@ define i32 @testpsz_128_signbit(<4 x float> %c, <4 x float> %d, i32 %a, i32 %b)
171171
}
172172

173173
define i32 @testpsnzc_256_signbit(<8 x float> %c, <8 x float> %d, i32 %a, i32 %b) {
174-
; AVX-LABEL: testpsnzc_256_signbit:
175-
; AVX: # %bb.0:
176-
; AVX-NEXT: movl %edi, %eax
177-
; AVX-NEXT: vcvtdq2ps %ymm0, %ymm0
178-
; AVX-NEXT: vxorps %xmm2, %xmm2, %xmm2
179-
; AVX-NEXT: vcmpltps %ymm2, %ymm0, %ymm0
180-
; AVX-NEXT: vtestps %ymm1, %ymm0
181-
; AVX-NEXT: cmovnel %esi, %eax
182-
; AVX-NEXT: vzeroupper
183-
; AVX-NEXT: retq
184-
;
185-
; AVX2-LABEL: testpsnzc_256_signbit:
186-
; AVX2: # %bb.0:
187-
; AVX2-NEXT: movl %edi, %eax
188-
; AVX2-NEXT: vtestps %ymm1, %ymm0
189-
; AVX2-NEXT: cmovnel %esi, %eax
190-
; AVX2-NEXT: vzeroupper
191-
; AVX2-NEXT: retq
174+
; CHECK-LABEL: testpsnzc_256_signbit:
175+
; CHECK: # %bb.0:
176+
; CHECK-NEXT: movl %edi, %eax
177+
; CHECK-NEXT: vtestps %ymm1, %ymm0
178+
; CHECK-NEXT: cmovnel %esi, %eax
179+
; CHECK-NEXT: vzeroupper
180+
; CHECK-NEXT: retq
192181
%t0 = bitcast <8 x float> %c to <8 x i32>
193182
%t1 = icmp sgt <8 x i32> zeroinitializer, %t0
194183
%t2 = sext <8 x i1> %t1 to <8 x i32>

llvm/test/CodeGen/X86/sadd_sat_vec.ll

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -861,9 +861,6 @@ define <8 x i32> @v8i32(<8 x i32> %x, <8 x i32> %y) nounwind {
861861
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm3, %xmm3
862862
; AVX1-NEXT: vpcmpgtd %xmm4, %xmm0, %xmm0
863863
; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm0, %ymm0
864-
; AVX1-NEXT: vcvtdq2ps %ymm1, %ymm1
865-
; AVX1-NEXT: vpxor %xmm3, %xmm3, %xmm3
866-
; AVX1-NEXT: vcmpltps %ymm3, %ymm1, %ymm1
867864
; AVX1-NEXT: vxorps %ymm0, %ymm1, %ymm0
868865
; AVX1-NEXT: vpsrad $31, %xmm4, %xmm1
869866
; AVX1-NEXT: vpsrad $31, %xmm2, %xmm2
@@ -1062,9 +1059,6 @@ define <16 x i32> @v16i32(<16 x i32> %x, <16 x i32> %y) nounwind {
10621059
; AVX1-NEXT: vpcmpgtd %xmm4, %xmm5, %xmm5
10631060
; AVX1-NEXT: vpcmpgtd %xmm6, %xmm0, %xmm0
10641061
; AVX1-NEXT: vinsertf128 $1, %xmm5, %ymm0, %ymm0
1065-
; AVX1-NEXT: vcvtdq2ps %ymm2, %ymm2
1066-
; AVX1-NEXT: vpxor %xmm5, %xmm5, %xmm5
1067-
; AVX1-NEXT: vcmpltps %ymm5, %ymm2, %ymm2
10681062
; AVX1-NEXT: vxorps %ymm0, %ymm2, %ymm0
10691063
; AVX1-NEXT: vpsrad $31, %xmm6, %xmm2
10701064
; AVX1-NEXT: vpsrad $31, %xmm4, %xmm4
@@ -1073,21 +1067,19 @@ define <16 x i32> @v16i32(<16 x i32> %x, <16 x i32> %y) nounwind {
10731067
; AVX1-NEXT: vxorps %ymm4, %ymm2, %ymm2
10741068
; AVX1-NEXT: vblendvps %ymm0, %ymm2, %ymm7, %ymm0
10751069
; AVX1-NEXT: vextractf128 $1, %ymm3, %xmm2
1076-
; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm6
1077-
; AVX1-NEXT: vpaddd %xmm2, %xmm6, %xmm2
1078-
; AVX1-NEXT: vpaddd %xmm3, %xmm1, %xmm7
1079-
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm7, %ymm8
1080-
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm6, %xmm6
1081-
; AVX1-NEXT: vpcmpgtd %xmm7, %xmm1, %xmm1
1082-
; AVX1-NEXT: vinsertf128 $1, %xmm6, %ymm1, %ymm1
1083-
; AVX1-NEXT: vcvtdq2ps %ymm3, %ymm3
1084-
; AVX1-NEXT: vcmpltps %ymm5, %ymm3, %ymm3
1070+
; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm5
1071+
; AVX1-NEXT: vpaddd %xmm2, %xmm5, %xmm2
1072+
; AVX1-NEXT: vpaddd %xmm3, %xmm1, %xmm6
1073+
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm6, %ymm7
1074+
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm5, %xmm5
1075+
; AVX1-NEXT: vpcmpgtd %xmm6, %xmm1, %xmm1
1076+
; AVX1-NEXT: vinsertf128 $1, %xmm5, %ymm1, %ymm1
10851077
; AVX1-NEXT: vxorps %ymm1, %ymm3, %ymm1
1086-
; AVX1-NEXT: vpsrad $31, %xmm7, %xmm3
1078+
; AVX1-NEXT: vpsrad $31, %xmm6, %xmm3
10871079
; AVX1-NEXT: vpsrad $31, %xmm2, %xmm2
10881080
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm3, %ymm2
10891081
; AVX1-NEXT: vxorps %ymm4, %ymm2, %ymm2
1090-
; AVX1-NEXT: vblendvps %ymm1, %ymm2, %ymm8, %ymm1
1082+
; AVX1-NEXT: vblendvps %ymm1, %ymm2, %ymm7, %ymm1
10911083
; AVX1-NEXT: retq
10921084
;
10931085
; AVX2-LABEL: v16i32:

0 commit comments

Comments
 (0)