Skip to content

Commit 0faca34

Browse files
committed
[X86] Improve helper for simplifying demanded bits of compares
We currently only handle a single case for `pcmpgt`. This patch extends that to work for `cmpp` and handles comparitors more generically.
1 parent ba7d9d1 commit 0faca34

File tree

6 files changed

+258
-64
lines changed

6 files changed

+258
-64
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 186 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41341,6 +41341,154 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
4134141341
return SDValue();
4134241342
}
4134341343

41344+
// Simplify a decomposed (sext (setcc)). Assumes prior check that
41345+
// bitwidth(sext)==bitwidth(setcc operands).
41346+
static SDValue simplifySExtOfDecomposedSetCCImpl(
41347+
SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0,
41348+
SDValue Op1, const APInt &OriginalDemandedBits,
41349+
const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) {
41350+
// Possible TODO: We could handle any power of two demanded bit + unsigned
41351+
// comparison. There are no x86 specific comparisons that are unsigned so its
41352+
// unneeded.
41353+
if (!OriginalDemandedBits.isSignMask())
41354+
return SDValue();
41355+
41356+
EVT OpVT = Op0.getValueType();
41357+
// We need need nofpclass(nan inf nzero) to handle floats.
41358+
auto hasOkayFPFlags = [](SDValue Op) {
41359+
return Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
41360+
Op->getFlags().hasNoSignedZeros();
41361+
};
41362+
41363+
if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
41364+
return SDValue();
41365+
41366+
auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
41367+
if (OpVT.isFloatingPoint()) {
41368+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41369+
return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
41370+
}
41371+
return V0.eq(V1);
41372+
};
41373+
41374+
// Assume we canonicalized constants to Op1. That isn't always true but we
41375+
// call this function twice with inverted CC/Operands so its fine either way.
41376+
APInt Op1C;
41377+
unsigned ValWidth = OriginalDemandedBits.getBitWidth();
41378+
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
41379+
Op1C = APInt::getZero(ValWidth);
41380+
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
41381+
Op1C = APInt::getAllOnes(ValWidth);
41382+
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
41383+
Op1C = C->getValueAPF().bitcastToAPInt();
41384+
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
41385+
Op1C = C->getAPIntValue();
41386+
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
41387+
// isConstantSplatVector sets `Op1C`.
41388+
} else {
41389+
return SDValue();
41390+
}
41391+
41392+
bool Not = false;
41393+
bool Okay = false;
41394+
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
41395+
"Invalid constant operand");
41396+
41397+
switch (CC) {
41398+
case ISD::SETGE:
41399+
case ISD::SETOGE:
41400+
Not = true;
41401+
[[fallthrough]];
41402+
case ISD::SETLT:
41403+
case ISD::SETOLT:
41404+
// signbit(sext(x s< 0)) == signbit(x)
41405+
// signbit(sext(x s>= 0)) == signbit(~x)
41406+
Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
41407+
// For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break
41408+
// this fold.
41409+
// NB: We only need de-norm check here, for the rest of the constants any
41410+
// relationship with a de-norm value and zero will be identical.
41411+
if (Okay && OpVT.isFloatingPoint()) {
41412+
// Values from integers are always normal.
41413+
if (Op0.getOpcode() == ISD::SINT_TO_FP ||
41414+
Op0.getOpcode() == ISD::UINT_TO_FP)
41415+
break;
41416+
41417+
// See if we can prove normal with known bits.
41418+
KnownBits Op0Known =
41419+
DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth);
41420+
// Negative/positive doesn't matter.
41421+
Op0Known.One.clearSignBit();
41422+
Op0Known.Zero.clearSignBit();
41423+
41424+
// Get min normal value.
41425+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41426+
KnownBits MinNormal = KnownBits::makeConstant(
41427+
APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
41428+
// Are we above de-norm range?
41429+
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
41430+
Okay = Op0Normal.value_or(false);
41431+
}
41432+
break;
41433+
case ISD::SETGT:
41434+
case ISD::SETOGT:
41435+
Not = true;
41436+
[[fallthrough]];
41437+
case ISD::SETLE:
41438+
case ISD::SETOLE:
41439+
// signbit(sext(x s<= -1)) == signbit(x)
41440+
// signbit(sext(x s> -1)) == signbit(~x)
41441+
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
41442+
break;
41443+
case ISD::SETULT:
41444+
Not = true;
41445+
[[fallthrough]];
41446+
case ISD::SETUGE:
41447+
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
41448+
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
41449+
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits);
41450+
break;
41451+
case ISD::SETULE:
41452+
Not = true;
41453+
[[fallthrough]];
41454+
case ISD::SETUGT:
41455+
// signbit(sext(x u> SIGNED_MAX)) == signbit(x)
41456+
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
41457+
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1);
41458+
break;
41459+
default:
41460+
break;
41461+
}
41462+
41463+
Okay &= Not ? AllowNOT : true;
41464+
if (!Okay)
41465+
return SDValue();
41466+
41467+
if (!Not)
41468+
return Op0;
41469+
41470+
if (!OpVT.isFloatingPoint())
41471+
return DAG.getNOT(DL, Op0, OpVT);
41472+
41473+
// Possible TODO: We could use `fneg` to do not.
41474+
return SDValue();
41475+
}
41476+
41477+
static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
41478+
ISD::CondCode CC, SDValue Op0,
41479+
SDValue Op1,
41480+
const APInt &OriginalDemandedBits,
41481+
const APInt &OriginalDemandedElts,
41482+
bool AllowNOT, unsigned Depth) {
41483+
if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
41484+
DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts,
41485+
AllowNOT, Depth))
41486+
return R;
41487+
return simplifySExtOfDecomposedSetCCImpl(
41488+
DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
41489+
OriginalDemandedElts, AllowNOT, Depth);
41490+
}
41491+
4134441492
// Simplify variable target shuffle masks based on the demanded elements.
4134541493
// TODO: Handle DemandedBits in mask indices as well?
4134641494
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42520,13 +42668,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4252042668
}
4252142669
break;
4252242670
}
42523-
case X86ISD::PCMPGT:
42524-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42525-
// iff we only need the sign bit then we can use R directly.
42526-
if (OriginalDemandedBits.isSignMask() &&
42527-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42528-
return TLO.CombineTo(Op, Op.getOperand(1));
42671+
case X86ISD::PCMPGT: {
42672+
SDLoc DL(Op);
42673+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42674+
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42675+
OriginalDemandedBits, OriginalDemandedElts,
42676+
/*AllowNOT*/ true, Depth))
42677+
return TLO.CombineTo(Op, R);
42678+
break;
42679+
}
42680+
case X86ISD::CMPP: {
42681+
SDLoc DL(Op);
42682+
ISD::CondCode CC = X86::getCondForCMPPImm(
42683+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42684+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42685+
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
42686+
OriginalDemandedBits, OriginalDemandedElts,
42687+
!(TLO.LegalOperations() && TLO.LegalTypes()), Depth))
42688+
return TLO.CombineTo(Op, R);
4252942689
break;
42690+
}
4253042691
case X86ISD::MOVMSK: {
4253142692
SDValue Src = Op.getOperand(0);
4253242693
MVT SrcVT = Src.getSimpleValueType();
@@ -42710,13 +42871,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
4271042871
if (DemandedBits.isSignMask())
4271142872
return Op.getOperand(0);
4271242873
break;
42713-
case X86ISD::PCMPGT:
42714-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42715-
// iff we only need the sign bit then we can use R directly.
42716-
if (DemandedBits.isSignMask() &&
42717-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42718-
return Op.getOperand(1);
42874+
case X86ISD::PCMPGT: {
42875+
SDLoc DL(Op);
42876+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42877+
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42878+
DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth))
42879+
return R;
42880+
break;
42881+
}
42882+
case X86ISD::CMPP: {
42883+
SDLoc DL(Op);
42884+
ISD::CondCode CC = X86::getCondForCMPPImm(
42885+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42886+
if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0),
42887+
Op.getOperand(1),
42888+
DemandedBits, DemandedElts,
42889+
/*AllowNOT*/ false, Depth))
42890+
return R;
4271942891
break;
42892+
}
4272042893
case X86ISD::BLENDV: {
4272142894
// BLENDV: Cond (MSB) ? LHS : RHS
4272242895
SDValue Cond = Op.getOperand(0);
@@ -48392,7 +48565,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,
4839248565

4839348566
// We do not split for SSE at all, but we need to split vectors for AVX1 and
4839448567
// AVX2.
48395-
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
48568+
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
4839648569
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) {
4839748570
SDValue LoX, HiX;
4839848571
std::tie(LoX, HiX) = splitVector(X, DAG, DL);

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,6 +3360,46 @@ unsigned X86::getVPCMPImmForCond(ISD::CondCode CC) {
33603360
}
33613361
}
33623362

3363+
ISD::CondCode X86::getCondForCMPPImm(unsigned Imm) {
3364+
assert(Imm <= 0x1f && "Invalid CMPP Imm");
3365+
switch (Imm & 0xf) {
3366+
default:
3367+
llvm_unreachable("Invalid CMPP Imm");
3368+
case 0:
3369+
return ISD::SETOEQ;
3370+
case 1:
3371+
return ISD::SETOLT;
3372+
case 2:
3373+
return ISD::SETOLE;
3374+
case 3:
3375+
return ISD::SETUO;
3376+
case 4:
3377+
return ISD::SETUNE;
3378+
case 5:
3379+
return ISD::SETUGE;
3380+
case 6:
3381+
return ISD::SETUGT;
3382+
case 7:
3383+
return ISD::SETO;
3384+
case 8:
3385+
return ISD::SETUEQ;
3386+
case 9:
3387+
return ISD::SETULT;
3388+
case 10:
3389+
return ISD::SETULE;
3390+
case 11:
3391+
return ISD::SETFALSE;
3392+
case 12:
3393+
return ISD::SETONE;
3394+
case 13:
3395+
return ISD::SETOGE;
3396+
case 14:
3397+
return ISD::SETOGT;
3398+
case 15:
3399+
return ISD::SETTRUE;
3400+
}
3401+
}
3402+
33633403
/// Get the VPCMP immediate if the operands are swapped.
33643404
unsigned X86::getSwappedVPCMPImm(unsigned Imm) {
33653405
switch (Imm) {

llvm/lib/Target/X86/X86InstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ CondCode GetOppositeBranchCondition(CondCode CC);
7272
/// Get the VPCMP immediate for the given condition.
7373
unsigned getVPCMPImmForCond(ISD::CondCode CC);
7474

75+
/// Get the CondCode from a CMPP immediate.
76+
ISD::CondCode getCondForCMPPImm(unsigned Imm);
77+
7578
/// Get the VPCMP immediate if the opcodes are swapped.
7679
unsigned getSwappedVPCMPImm(unsigned Imm);
7780

llvm/test/CodeGen/X86/combine-testps.ll

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -171,24 +171,13 @@ define i32 @testpsz_128_signbit(<4 x float> %c, <4 x float> %d, i32 %a, i32 %b)
171171
}
172172

173173
define i32 @testpsnzc_256_signbit(<8 x float> %c, <8 x float> %d, i32 %a, i32 %b) {
174-
; AVX-LABEL: testpsnzc_256_signbit:
175-
; AVX: # %bb.0:
176-
; AVX-NEXT: movl %edi, %eax
177-
; AVX-NEXT: vcvtdq2ps %ymm0, %ymm0
178-
; AVX-NEXT: vxorps %xmm2, %xmm2, %xmm2
179-
; AVX-NEXT: vcmpltps %ymm2, %ymm0, %ymm0
180-
; AVX-NEXT: vtestps %ymm1, %ymm0
181-
; AVX-NEXT: cmovnel %esi, %eax
182-
; AVX-NEXT: vzeroupper
183-
; AVX-NEXT: retq
184-
;
185-
; AVX2-LABEL: testpsnzc_256_signbit:
186-
; AVX2: # %bb.0:
187-
; AVX2-NEXT: movl %edi, %eax
188-
; AVX2-NEXT: vtestps %ymm1, %ymm0
189-
; AVX2-NEXT: cmovnel %esi, %eax
190-
; AVX2-NEXT: vzeroupper
191-
; AVX2-NEXT: retq
174+
; CHECK-LABEL: testpsnzc_256_signbit:
175+
; CHECK: # %bb.0:
176+
; CHECK-NEXT: movl %edi, %eax
177+
; CHECK-NEXT: vtestps %ymm1, %ymm0
178+
; CHECK-NEXT: cmovnel %esi, %eax
179+
; CHECK-NEXT: vzeroupper
180+
; CHECK-NEXT: retq
192181
%t0 = bitcast <8 x float> %c to <8 x i32>
193182
%t1 = icmp sgt <8 x i32> zeroinitializer, %t0
194183
%t2 = sext <8 x i1> %t1 to <8 x i32>

llvm/test/CodeGen/X86/sadd_sat_vec.ll

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -861,9 +861,6 @@ define <8 x i32> @v8i32(<8 x i32> %x, <8 x i32> %y) nounwind {
861861
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm3, %xmm3
862862
; AVX1-NEXT: vpcmpgtd %xmm4, %xmm0, %xmm0
863863
; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm0, %ymm0
864-
; AVX1-NEXT: vcvtdq2ps %ymm1, %ymm1
865-
; AVX1-NEXT: vpxor %xmm3, %xmm3, %xmm3
866-
; AVX1-NEXT: vcmpltps %ymm3, %ymm1, %ymm1
867864
; AVX1-NEXT: vxorps %ymm0, %ymm1, %ymm0
868865
; AVX1-NEXT: vpsrad $31, %xmm4, %xmm1
869866
; AVX1-NEXT: vpsrad $31, %xmm2, %xmm2
@@ -1062,9 +1059,6 @@ define <16 x i32> @v16i32(<16 x i32> %x, <16 x i32> %y) nounwind {
10621059
; AVX1-NEXT: vpcmpgtd %xmm4, %xmm5, %xmm5
10631060
; AVX1-NEXT: vpcmpgtd %xmm6, %xmm0, %xmm0
10641061
; AVX1-NEXT: vinsertf128 $1, %xmm5, %ymm0, %ymm0
1065-
; AVX1-NEXT: vcvtdq2ps %ymm2, %ymm2
1066-
; AVX1-NEXT: vpxor %xmm5, %xmm5, %xmm5
1067-
; AVX1-NEXT: vcmpltps %ymm5, %ymm2, %ymm2
10681062
; AVX1-NEXT: vxorps %ymm0, %ymm2, %ymm0
10691063
; AVX1-NEXT: vpsrad $31, %xmm6, %xmm2
10701064
; AVX1-NEXT: vpsrad $31, %xmm4, %xmm4
@@ -1073,21 +1067,19 @@ define <16 x i32> @v16i32(<16 x i32> %x, <16 x i32> %y) nounwind {
10731067
; AVX1-NEXT: vxorps %ymm4, %ymm2, %ymm2
10741068
; AVX1-NEXT: vblendvps %ymm0, %ymm2, %ymm7, %ymm0
10751069
; AVX1-NEXT: vextractf128 $1, %ymm3, %xmm2
1076-
; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm6
1077-
; AVX1-NEXT: vpaddd %xmm2, %xmm6, %xmm2
1078-
; AVX1-NEXT: vpaddd %xmm3, %xmm1, %xmm7
1079-
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm7, %ymm8
1080-
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm6, %xmm6
1081-
; AVX1-NEXT: vpcmpgtd %xmm7, %xmm1, %xmm1
1082-
; AVX1-NEXT: vinsertf128 $1, %xmm6, %ymm1, %ymm1
1083-
; AVX1-NEXT: vcvtdq2ps %ymm3, %ymm3
1084-
; AVX1-NEXT: vcmpltps %ymm5, %ymm3, %ymm3
1070+
; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm5
1071+
; AVX1-NEXT: vpaddd %xmm2, %xmm5, %xmm2
1072+
; AVX1-NEXT: vpaddd %xmm3, %xmm1, %xmm6
1073+
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm6, %ymm7
1074+
; AVX1-NEXT: vpcmpgtd %xmm2, %xmm5, %xmm5
1075+
; AVX1-NEXT: vpcmpgtd %xmm6, %xmm1, %xmm1
1076+
; AVX1-NEXT: vinsertf128 $1, %xmm5, %ymm1, %ymm1
10851077
; AVX1-NEXT: vxorps %ymm1, %ymm3, %ymm1
1086-
; AVX1-NEXT: vpsrad $31, %xmm7, %xmm3
1078+
; AVX1-NEXT: vpsrad $31, %xmm6, %xmm3
10871079
; AVX1-NEXT: vpsrad $31, %xmm2, %xmm2
10881080
; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm3, %ymm2
10891081
; AVX1-NEXT: vxorps %ymm4, %ymm2, %ymm2
1090-
; AVX1-NEXT: vblendvps %ymm1, %ymm2, %ymm8, %ymm1
1082+
; AVX1-NEXT: vblendvps %ymm1, %ymm2, %ymm7, %ymm1
10911083
; AVX1-NEXT: retq
10921084
;
10931085
; AVX2-LABEL: v16i32:

0 commit comments

Comments
 (0)