Skip to content

Commit bf63272

Browse files
committed
[X86] Try Folding icmp of v8i32 -> fcmp of v8f32 on AVX
Fixes: #82242 The idea is that AVX (without AVX2) doesn't support integer comparisons on `v8i32`, so it splits the comparison into 2x `v4i32` comparisons plus reconstruction of the `v8i32` result. By converting the operands to float, we can handle the comparison with 1 or 2 instructions (1 if we can `bitcast`, 2 if we need to convert with `sitofp`). Proofs: https://alive2.llvm.org/ce/z/AJDdQ8 (these time out in the online tool, but they can be reproduced locally).
1 parent c0d5e32 commit bf63272

33 files changed

+2335
-2289
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "llvm/CodeGen/TargetLowering.h"
14+
#include "llvm/ADT/APFloat.h"
1415
#include "llvm/ADT/STLExtras.h"
1516
#include "llvm/Analysis/VectorUtils.h"
1617
#include "llvm/CodeGen/CallingConvLower.h"
@@ -21,6 +22,7 @@
2122
#include "llvm/CodeGen/MachineModuleInfoImpls.h"
2223
#include "llvm/CodeGen/MachineRegisterInfo.h"
2324
#include "llvm/CodeGen/SelectionDAG.h"
25+
#include "llvm/CodeGen/SelectionDAGNodes.h"
2426
#include "llvm/CodeGen/TargetRegisterInfo.h"
2527
#include "llvm/IR/DataLayout.h"
2628
#include "llvm/IR/DerivedTypes.h"
@@ -936,6 +938,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
936938
Depth);
937939
}
938940

941+
939942
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
940943
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
941944
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
@@ -2471,6 +2474,8 @@ bool TargetLowering::SimplifyDemandedBits(
24712474
if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
24722475
Depth + 1))
24732476
return true;
2477+
2478+
24742479
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
24752480
assert(Known.getBitWidth() == InBits && "Src width has changed?");
24762481

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 274 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23299,6 +23299,126 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2329923299
}
2330023300
}
2330123301

23302+
// We get bad codegen for v8i32 compares on avx targets (without avx2) so if
23303+
// possible convert to a v8f32 compare.
23304+
if (VTOp0 == MVT::v8i32 && Subtarget.hasAVX() && !Subtarget.hasAVX2()) {
23305+
std::optional<KnownBits> KnownOps[2];
23306+
// Check if an op is known to be in a certain range.
23307+
auto OpInRange = [&DAG, Op, &KnownOps](unsigned OpNo, bool CmpLT,
23308+
const APInt Bound) {
23309+
if (!KnownOps[OpNo].has_value())
23310+
KnownOps[OpNo] = DAG.computeKnownBits(Op.getOperand(OpNo));
23311+
23312+
if (KnownOps[OpNo]->isUnknown())
23313+
return false;
23314+
23315+
std::optional<bool> Res;
23316+
if (CmpLT)
23317+
Res = KnownBits::ult(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
23318+
else
23319+
Res = KnownBits::ugt(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
23320+
return Res.has_value() && *Res;
23321+
};
23322+
23323+
bool OkayCvt = false;
23324+
bool OkayBitcast = false;
23325+
23326+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(MVT::f32);
23327+
23328+
// For cvt up to 1 << (Significand Precision), (1 << 24 for ieee float)
23329+
const APInt MaxConvertableCvt =
23330+
APInt(32, (1U << APFloat::semanticsPrecision(Sem)));
23331+
// For bitcast up to (and including) first inf representation (0x7f800000 +
23332+
// 1 for ieee float)
23333+
const APInt MaxConvertableBitcast =
23334+
APFloat::getInf(Sem).bitcastToAPInt() + 1;
23335+
23336+
assert(
23337+
MaxConvertableBitcast.getBitWidth() == 32 &&
23338+
MaxConvertableCvt == (1U << 24) &&
23339+
MaxConvertableBitcast == 0x7f800001 &&
23340+
"This transform has only been verified to IEEE Single Precision Float");
23341+
23342+
// For bitcast we need both lhs/op1 u< MaxConvertableBitcast
23343+
// NB: It might be worth it to enable the bitcast version for unsigned avx2
23344+
// comparisons as they typically require multiple instructions to lower
23345+
// (they don't fit `vpcmpeq`/`vpcmpgt` well).
23346+
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableBitcast) &&
23347+
OpInRange(0, /*CmpLT*/ true, MaxConvertableBitcast)) {
23348+
OkayBitcast = true;
23349+
}
23350+
// We want to convert icmp -> fcmp using `sitofp` iff one of the converts
23351+
// will be constant folded.
23352+
else if ((DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op1)) ||
23353+
DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op0)))) {
23354+
if (isUnsignedIntSetCC(Cond)) {
23355+
// For cvt + unsigned compare we need both lhs/rhs >= 0 and either lhs
23356+
// or rhs < MaxConvertableCvt
23357+
23358+
if (OpInRange(1, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
23359+
OpInRange(0, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
23360+
(OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
23361+
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt)))
23362+
OkayCvt = true;
23363+
} else {
23364+
// For cvt + signed compare we need abs(lhs) or abs(rhs) <
23365+
// MaxConvertableCvt
23366+
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
23367+
OpInRange(1, /*CmpLT*/ false, -MaxConvertableCvt) ||
23368+
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt) ||
23369+
OpInRange(0, /*CmpLT*/ false, -MaxConvertableCvt))
23370+
OkayCvt = true;
23371+
}
23372+
}
23373+
23374+
if (OkayBitcast || OkayCvt) {
23375+
switch (Cond) {
23376+
default:
23377+
llvm_unreachable("Unexpected SETCC condition");
23378+
// Get the new FP condition. Note for the unsigned conditions we have
23379+
// verified it's okay to convert to the signed version.
23380+
case ISD::SETULT:
23381+
case ISD::SETLT:
23382+
Cond = ISD::SETOLT;
23383+
break;
23384+
case ISD::SETUGT:
23385+
case ISD::SETGT:
23386+
Cond = ISD::SETOGT;
23387+
break;
23388+
case ISD::SETULE:
23389+
case ISD::SETLE:
23390+
Cond = ISD::SETOLE;
23391+
break;
23392+
case ISD::SETUGE:
23393+
case ISD::SETGE:
23394+
Cond = ISD::SETOGE;
23395+
break;
23396+
case ISD::SETEQ:
23397+
Cond = ISD::SETOEQ;
23398+
break;
23399+
case ISD::SETNE:
23400+
Cond = ISD::SETONE;
23401+
break;
23402+
}
23403+
23404+
MVT FpVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
23405+
SDNodeFlags Flags;
23406+
Flags.setNoNaNs(true);
23407+
Flags.setNoInfs(true);
23408+
Flags.setNoSignedZeros(true);
23409+
if (OkayBitcast) {
23410+
Op0 = DAG.getBitcast(FpVT, Op0);
23411+
Op1 = DAG.getBitcast(FpVT, Op1);
23412+
} else {
23413+
Op0 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op0);
23414+
Op1 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op1);
23415+
}
23416+
Op0->setFlags(Flags);
23417+
Op1->setFlags(Flags);
23418+
return DAG.getSetCC(dl, VT, Op0, Op1, Cond);
23419+
}
23420+
}
23421+
2330223422
// Break 256-bit integer vector compare into smaller ones.
2330323423
if (VT.is256BitVector() && !Subtarget.hasInt256())
2330423424
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
@@ -41037,6 +41157,126 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
4103741157
return SDValue();
4103841158
}
4103941159

41160+
// Simplify a decomposed (sext (setcc)). Assumes prior check that
41161+
// bitwidth(sext)==bitwidth(setcc operands).
41162+
static SDValue simplifySExtOfDecomposedSetCCImpl(
41163+
SelectionDAG &DAG, SDLoc &DL, ISD::CondCode CC, SDValue Op0, SDValue Op1,
41164+
const APInt &OriginalDemandedBits, bool AllowNOT) {
41165+
// Possible TODO: We could handle any power of two demanded bit + unsigned
41166+
// comparison. There are no x86 specific comparisons that are unsigned so its
41167+
// unneeded.
41168+
if (!OriginalDemandedBits.isSignMask())
41169+
return SDValue();
41170+
41171+
EVT OpVT = Op0.getValueType();
41172+
// We need need nofpclass(nan inf nzero) to handle floats.
41173+
auto hasOkayFPFlags = [](SDValue Op) {
41174+
return Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
41175+
Op->getFlags().hasNoSignedZeros();
41176+
};
41177+
41178+
if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
41179+
return SDValue();
41180+
41181+
auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
41182+
if (OpVT.isFloatingPoint()) {
41183+
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
41184+
return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
41185+
}
41186+
return V0.eq(V1);
41187+
};
41188+
41189+
// Assume we canonicalized constants to Op1. That isn't always true but we
41190+
// call this function twice with inverted CC/Operands so its fine either way.
41191+
APInt Op1C;
41192+
unsigned ValWidth = OriginalDemandedBits.getBitWidth();
41193+
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
41194+
Op1C = APInt::getZero(ValWidth);
41195+
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
41196+
Op1C = APInt::getAllOnes(ValWidth);
41197+
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
41198+
Op1C = C->getValueAPF().bitcastToAPInt();
41199+
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
41200+
Op1C = C->getAPIntValue();
41201+
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
41202+
// Pass
41203+
} else {
41204+
return SDValue();
41205+
}
41206+
41207+
bool Not = false;
41208+
bool Okay = false;
41209+
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
41210+
"Invalid constant operand");
41211+
41212+
switch (CC) {
41213+
case ISD::SETGE:
41214+
case ISD::SETOGE:
41215+
Not = true;
41216+
[[fallthrough]];
41217+
case ISD::SETLT:
41218+
case ISD::SETOLT:
41219+
// signbit(sext(x s< 0)) == signbit(x)
41220+
// signbit(sext(x s>= 0)) == signbit(~x)
41221+
Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
41222+
break;
41223+
case ISD::SETGT:
41224+
case ISD::SETOGT:
41225+
Not = true;
41226+
[[fallthrough]];
41227+
case ISD::SETLE:
41228+
case ISD::SETOLE:
41229+
// signbit(sext(x s<= -1)) == signbit(x)
41230+
// signbit(sext(x s> -1)) == signbit(~x)
41231+
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
41232+
break;
41233+
case ISD::SETULT:
41234+
Not = true;
41235+
[[fallthrough]];
41236+
case ISD::SETUGE:
41237+
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
41238+
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
41239+
Okay = ValsEq(Op1C, OriginalDemandedBits);
41240+
break;
41241+
case ISD::SETULE:
41242+
Not = true;
41243+
[[fallthrough]];
41244+
case ISD::SETUGT:
41245+
// signbit(sext(x u> SIGNED_MAX)) == signbit(x)
41246+
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
41247+
Okay = ValsEq(Op1C, OriginalDemandedBits - 1);
41248+
break;
41249+
default:
41250+
break;
41251+
}
41252+
41253+
Okay = Not ? AllowNOT : Okay;
41254+
if (!Okay)
41255+
return SDValue();
41256+
41257+
if (!Not)
41258+
return Op0;
41259+
41260+
if (!OpVT.isFloatingPoint())
41261+
return DAG.getNOT(DL, Op0, OpVT);
41262+
41263+
// Possible TODO: We could use `fneg` to do not.
41264+
return SDValue();
41265+
}
41266+
41267+
// Entry point for simplifying a decomposed (sext (setcc)). The implementation
// expects the constant in the second operand, so first try the operands as
// given and, if that yields nothing, retry with the operands exchanged and the
// condition code adjusted for the swap.
static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
                                             ISD::CondCode CC, SDValue Op0,
                                             SDValue Op1,
                                             const APInt &OriginalDemandedBits,
                                             bool AllowNOT) {
  SDValue Simplified = simplifySExtOfDecomposedSetCCImpl(
      DAG, DL, CC, Op0, Op1, OriginalDemandedBits, AllowNOT);

  if (!Simplified) {
    const ISD::CondCode SwappedCC = ISD::getSetCCSwappedOperands(CC);
    Simplified = simplifySExtOfDecomposedSetCCImpl(
        DAG, DL, SwappedCC, Op1, Op0, OriginalDemandedBits, AllowNOT);
  }

  return Simplified;
}
41279+
4104041280
// Simplify variable target shuffle masks based on the demanded elements.
4104141281
// TODO: Handle DemandedBits in mask indices as well?
4104241282
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42200,13 +42440,24 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4220042440
}
4220142441
break;
4220242442
}
42203-
case X86ISD::PCMPGT:
42204-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42205-
// iff we only need the sign bit then we can use R directly.
42206-
if (OriginalDemandedBits.isSignMask() &&
42207-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42208-
return TLO.CombineTo(Op, Op.getOperand(1));
42443+
case X86ISD::PCMPGT: {
42444+
SDLoc DL(Op);
42445+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42446+
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42447+
OriginalDemandedBits, !(TLO.LegalOperations() && TLO.LegalTypes())))
42448+
return TLO.CombineTo(Op, R);
42449+
break;
42450+
}
42451+
case X86ISD::CMPP: {
42452+
SDLoc DL(Op);
42453+
ISD::CondCode CC = X86::getCondForCMPPImm(
42454+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42455+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42456+
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
42457+
OriginalDemandedBits, !(TLO.LegalOperations() && TLO.LegalTypes())))
42458+
return TLO.CombineTo(Op, R);
4220942459
break;
42460+
}
4221042461
case X86ISD::MOVMSK: {
4221142462
SDValue Src = Op.getOperand(0);
4221242463
MVT SrcVT = Src.getSimpleValueType();
@@ -42390,13 +42641,24 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
4239042641
if (DemandedBits.isSignMask())
4239142642
return Op.getOperand(0);
4239242643
break;
42393-
case X86ISD::PCMPGT:
42394-
// icmp sgt(0, R) == ashr(R, BitWidth-1).
42395-
// iff we only need the sign bit then we can use R directly.
42396-
if (DemandedBits.isSignMask() &&
42397-
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
42398-
return Op.getOperand(1);
42644+
case X86ISD::PCMPGT: {
42645+
SDLoc DL(Op);
42646+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42647+
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
42648+
DemandedBits, /*AllowNOT*/ false))
42649+
return R;
42650+
break;
42651+
}
42652+
case X86ISD::CMPP: {
42653+
SDLoc DL(Op);
42654+
ISD::CondCode CC = X86::getCondForCMPPImm(
42655+
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
42656+
if (SDValue R = simplifySExtOfDecomposedSetCC(
42657+
DAG, DL, CC, Op.getOperand(0), Op.getOperand(1), DemandedBits,
42658+
/*AllowNOT*/ false))
42659+
return R;
4239942660
break;
42661+
}
4240042662
case X86ISD::BLENDV: {
4240142663
// BLENDV: Cond (MSB) ? LHS : RHS
4240242664
SDValue Cond = Op.getOperand(0);

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3349,6 +3349,46 @@ unsigned X86::getVPCMPImmForCond(ISD::CondCode CC) {
33493349
}
33503350
}
33513351

3352+
/// Map a CMPP comparison immediate to the corresponding ISD condition code.
///
/// Immediates up to 0x1f are accepted, but only the low four bits select the
/// returned condition; bit 4 is ignored by this mapping.
ISD::CondCode X86::getCondForCMPPImm(unsigned Imm) {
  assert(Imm <= 0x1f && "Invalid CMPP Imm");

  // Indexed by the low four bits of the immediate.
  static const ISD::CondCode PredicateToCC[16] = {
      ISD::SETOEQ,   // 0
      ISD::SETOLT,   // 1
      ISD::SETOLE,   // 2
      ISD::SETUO,    // 3
      ISD::SETUNE,   // 4
      ISD::SETUGE,   // 5
      ISD::SETUGT,   // 6
      ISD::SETO,     // 7
      ISD::SETUEQ,   // 8
      ISD::SETULT,   // 9
      ISD::SETULE,   // 10
      ISD::SETFALSE, // 11
      ISD::SETONE,   // 12
      ISD::SETOGE,   // 13
      ISD::SETOGT,   // 14
      ISD::SETTRUE,  // 15
  };

  return PredicateToCC[Imm & 0xf];
}
3391+
33523392
/// Get the VPCMP immediate if the operands are swapped.
33533393
unsigned X86::getSwappedVPCMPImm(unsigned Imm) {
33543394
switch (Imm) {

0 commit comments

Comments
 (0)