-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[X86] Improve helper for simplifying demanded bits of compares #84360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23391,6 +23391,136 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, | |
} | ||
} | ||
|
||
// We get bad codegen for v8i32 compares on avx targets (without avx2) so if | ||
// possible convert to a v8f32 compare. | ||
if (VTOp0 == MVT::v8i32 && Subtarget.hasAVX() && !Subtarget.hasAVX2()) { | ||
std::optional<KnownBits> KnownOps[2]; | ||
// Check if an op is known to be in a certain range. | ||
auto OpInRange = [&DAG, Op, &KnownOps](unsigned OpNo, bool CmpLT, | ||
const APInt Bound) { | ||
if (!KnownOps[OpNo].has_value()) | ||
KnownOps[OpNo] = DAG.computeKnownBits(Op.getOperand(OpNo)); | ||
|
||
if (KnownOps[OpNo]->isUnknown()) | ||
return false; | ||
|
||
std::optional<bool> Res; | ||
if (CmpLT) | ||
Res = KnownBits::ult(*KnownOps[OpNo], KnownBits::makeConstant(Bound)); | ||
else | ||
Res = KnownBits::ugt(*KnownOps[OpNo], KnownBits::makeConstant(Bound)); | ||
return Res.value_or(false); | ||
}; | ||
|
||
bool OkayCvt = false; | ||
bool OkayBitcast = false; | ||
|
||
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(MVT::f32); | ||
|
||
// For cvt up to 1 << (Significand Precision), (1 << 24 for ieee float) | ||
const APInt MaxConvertableCvt = | ||
APInt::getOneBitSet(32, APFloat::semanticsPrecision(Sem)); | ||
// For bitcast up to (and including) first inf representation (0x7f800000 + | ||
// 1 for ieee float) | ||
const APInt MaxConvertableBitcast = | ||
APFloat::getInf(Sem).bitcastToAPInt() + 1; | ||
// For bitcast we also exclude de-norm values. This is absolutely necessary | ||
// for strict semantic correctness, but DAZ (de-norm as zero) will break if | ||
// we don't have this check. | ||
const APInt MinConvertableBitcast = | ||
APFloat::getSmallestNormalized(Sem).bitcastToAPInt() - 1; | ||
|
||
assert( | ||
MaxConvertableBitcast.getBitWidth() == 32 && | ||
MaxConvertableCvt == (1U << 24) && | ||
MaxConvertableBitcast == 0x7f800001 && | ||
MinConvertableBitcast.isNonNegative() && | ||
MaxConvertableBitcast.sgt(MinConvertableBitcast) && | ||
"This transform has only been verified to IEEE Single Precision Float"); | ||
|
||
// For bitcast we need both lhs/op1 u< MaxConvertableBitcast | ||
// NB: It might be worth it to enable to bitcast version for unsigned avx2 | ||
// comparisons as they typically require multiple instructions to lower | ||
// (they don't fit `vpcmpeq`/`vpcmpgt` well). | ||
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableBitcast) && | ||
OpInRange(1, /*CmpLT*/ false, MinConvertableBitcast) && | ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableBitcast) && | ||
OpInRange(0, /*CmpLT*/ false, MinConvertableBitcast)) { | ||
OkayBitcast = true; | ||
} | ||
// We want to convert icmp -> fcmp using `sitofp` iff one of the converts | ||
// will be constant folded. | ||
else if ((DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op1)) || | ||
DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op0)))) { | ||
if (isUnsignedIntSetCC(Cond)) { | ||
// For cvt + unsigned compare we need both lhs/rhs >= 0 and either lhs | ||
// or rhs < MaxConvertableCvt | ||
|
||
if (OpInRange(1, /*CmpLT*/ true, APInt::getSignedMinValue(32)) && | ||
OpInRange(0, /*CmpLT*/ true, APInt::getSignedMinValue(32)) && | ||
(OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) || | ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt))) | ||
OkayCvt = true; | ||
} else { | ||
// For cvt + signed compare we need abs(lhs) or abs(rhs) < | ||
// MaxConvertableCvt | ||
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) || | ||
OpInRange(1, /*CmpLT*/ false, -MaxConvertableCvt) || | ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt) || | ||
OpInRange(0, /*CmpLT*/ false, -MaxConvertableCvt)) | ||
OkayCvt = true; | ||
} | ||
} | ||
// TODO: If we can't prove any of the ranges, we could unconditionally lower | ||
// `(icmp eq lhs, rhs)` as `(icmp eq (int_to_fp (xor lhs, rhs)), zero)` | ||
if (OkayBitcast || OkayCvt) { | ||
switch (Cond) { | ||
default: | ||
llvm_unreachable("Unexpected SETCC condition"); | ||
// Get the new FP condition. Note for the unsigned conditions we have | ||
// verified its okay to convert to the signed version. | ||
case ISD::SETULT: | ||
case ISD::SETLT: | ||
Cond = ISD::SETOLT; | ||
break; | ||
case ISD::SETUGT: | ||
case ISD::SETGT: | ||
Cond = ISD::SETOGT; | ||
break; | ||
case ISD::SETULE: | ||
case ISD::SETLE: | ||
Cond = ISD::SETOLE; | ||
break; | ||
case ISD::SETUGE: | ||
case ISD::SETGE: | ||
Cond = ISD::SETOGE; | ||
break; | ||
case ISD::SETEQ: | ||
Cond = ISD::SETOEQ; | ||
break; | ||
case ISD::SETNE: | ||
Cond = ISD::SETONE; | ||
break; | ||
} | ||
|
||
MVT FpVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); | ||
SDNodeFlags Flags; | ||
Flags.setNoNaNs(true); | ||
Flags.setNoInfs(true); | ||
Flags.setNoSignedZeros(true); | ||
if (OkayBitcast) { | ||
Op0 = DAG.getBitcast(FpVT, Op0); | ||
Op1 = DAG.getBitcast(FpVT, Op1); | ||
} else { | ||
Op0 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op0); | ||
Op1 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op1); | ||
} | ||
Op0->setFlags(Flags); | ||
Op1->setFlags(Flags); | ||
return DAG.getSetCC(dl, VT, Op0, Op1, Cond); | ||
} | ||
} | ||
|
||
// Break 256-bit integer vector compare into smaller ones. | ||
if (VT.is256BitVector() && !Subtarget.hasInt256()) | ||
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl); | ||
|
@@ -41216,6 +41346,156 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, | |
return SDValue(); | ||
} | ||
|
||
// Simplify a decomposed (sext (setcc)). Assumes prior check that | ||
// bitwidth(sext)==bitwidth(setcc operands). | ||
static SDValue simplifySExtOfDecomposedSetCCImpl( | ||
SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0, | ||
SDValue Op1, const APInt &OriginalDemandedBits, | ||
const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) { | ||
// Possible TODO: We could handle any power of two demanded bit + unsigned | ||
// comparison. There are no x86 specific comparisons that are unsigned so its | ||
// unneeded. | ||
if (!OriginalDemandedBits.isSignMask()) | ||
return SDValue(); | ||
|
||
EVT OpVT = Op0.getValueType(); | ||
// We need need nofpclass(nan inf nzero) to handle floats. | ||
auto hasOkayFPFlags = [](SDValue Op) { | ||
return Op.getOpcode() == ISD::SINT_TO_FP || | ||
Op.getOpcode() == ISD::UINT_TO_FP || | ||
(Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() && | ||
Op->getFlags().hasNoSignedZeros()); | ||
}; | ||
|
||
if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0)) | ||
return SDValue(); | ||
|
||
auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool { | ||
if (OpVT.isFloatingPoint()) { | ||
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT); | ||
return V0.eq(APFloat(Sem, V1).bitcastToAPInt()); | ||
} | ||
return V0.eq(V1); | ||
}; | ||
|
||
// Assume we canonicalized constants to Op1. That isn't always true but we | ||
// call this function twice with inverted CC/Operands so its fine either way. | ||
APInt Op1C; | ||
unsigned ValWidth = OriginalDemandedBits.getBitWidth(); | ||
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) { | ||
Op1C = APInt::getZero(ValWidth); | ||
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) { | ||
Op1C = APInt::getAllOnes(ValWidth); | ||
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) { | ||
Op1C = C->getValueAPF().bitcastToAPInt(); | ||
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) { | ||
Op1C = C->getAPIntValue(); | ||
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) { | ||
// isConstantSplatVector sets `Op1C`. | ||
} else { | ||
return SDValue(); | ||
} | ||
|
||
bool Not = false; | ||
bool Okay = false; | ||
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() && | ||
"Invalid constant operand"); | ||
|
||
switch (CC) { | ||
case ISD::SETGE: | ||
case ISD::SETOGE: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETLT: | ||
case ISD::SETOLT: | ||
// signbit(sext(x s< 0)) == signbit(x) | ||
// signbit(sext(x s>= 0)) == signbit(~x) | ||
Okay = ValsEq(Op1C, APInt::getZero(ValWidth)); | ||
// For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break | ||
// this fold. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was a bad bug, patch is no longer so sexy. |
||
// NB: We only need de-norm check here, for the rest of the constants any | ||
// relationship with a de-norm value and zero will be identical. | ||
if (Okay && OpVT.isFloatingPoint()) { | ||
// Values from integers are always normal. | ||
if (Op0.getOpcode() == ISD::SINT_TO_FP || | ||
Op0.getOpcode() == ISD::UINT_TO_FP) | ||
break; | ||
|
||
// See if we can prove normal with known bits. | ||
KnownBits Op0Known = | ||
DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth); | ||
// Negative/positive doesn't matter. | ||
Op0Known.One.clearSignBit(); | ||
Op0Known.Zero.clearSignBit(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
// Get min normal value. | ||
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT); | ||
KnownBits MinNormal = KnownBits::makeConstant( | ||
APFloat::getSmallestNormalized(Sem).bitcastToAPInt()); | ||
// Are we above de-norm range? | ||
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal); | ||
Okay = Op0Normal.value_or(false); | ||
} | ||
break; | ||
case ISD::SETGT: | ||
case ISD::SETOGT: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETLE: | ||
case ISD::SETOLE: | ||
// signbit(sext(x s<= -1)) == signbit(x) | ||
// signbit(sext(x s> -1)) == signbit(~x) | ||
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth)); | ||
break; | ||
case ISD::SETULT: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETUGE: | ||
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x) | ||
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x) | ||
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits); | ||
break; | ||
case ISD::SETULE: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETUGT: | ||
// signbit(sext(x u> SIGNED_MAX)) == signbit(x) | ||
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x) | ||
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1); | ||
break; | ||
default: | ||
break; | ||
} | ||
|
||
Okay &= Not ? AllowNOT : true; | ||
if (!Okay) | ||
return SDValue(); | ||
|
||
if (!Not) | ||
return Op0; | ||
|
||
if (!OpVT.isFloatingPoint()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const SDLoc & |
||
return DAG.getNOT(DL, Op0, OpVT); | ||
|
||
// Possible TODO: We could use `fneg` to do not. | ||
return SDValue(); | ||
} | ||
|
||
static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, const SDLoc &DL, | ||
ISD::CondCode CC, SDValue Op0, | ||
SDValue Op1, | ||
const APInt &OriginalDemandedBits, | ||
const APInt &OriginalDemandedElts, | ||
bool AllowNOT, unsigned Depth) { | ||
if (SDValue R = simplifySExtOfDecomposedSetCCImpl( | ||
DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts, | ||
AllowNOT, Depth)) | ||
return R; | ||
return simplifySExtOfDecomposedSetCCImpl( | ||
DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits, | ||
OriginalDemandedElts, AllowNOT, Depth); | ||
} | ||
|
||
// Simplify variable target shuffle masks based on the demanded elements. | ||
// TODO: Handle DemandedBits in mask indices as well? | ||
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle( | ||
|
@@ -42395,13 +42675,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( | |
} | ||
break; | ||
} | ||
case X86ISD::PCMPGT: | ||
// icmp sgt(0, R) == ashr(R, BitWidth-1). | ||
// iff we only need the sign bit then we can use R directly. | ||
if (OriginalDemandedBits.isSignMask() && | ||
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) | ||
return TLO.CombineTo(Op, Op.getOperand(1)); | ||
case X86ISD::PCMPGT: { | ||
SDLoc DL(Op); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC( | ||
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1), | ||
OriginalDemandedBits, OriginalDemandedElts, | ||
/*AllowNOT*/ true, Depth)) | ||
return TLO.CombineTo(Op, R); | ||
break; | ||
} | ||
case X86ISD::CMPP: { | ||
SDLoc DL(Op); | ||
ISD::CondCode CC = X86::getCondForCMPPImm( | ||
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue()); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC( | ||
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1), | ||
OriginalDemandedBits, OriginalDemandedElts, | ||
!(TLO.LegalOperations() && TLO.LegalTypes()), Depth)) | ||
return TLO.CombineTo(Op, R); | ||
break; | ||
} | ||
case X86ISD::MOVMSK: { | ||
SDValue Src = Op.getOperand(0); | ||
MVT SrcVT = Src.getSimpleValueType(); | ||
|
@@ -42585,13 +42878,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode( | |
if (DemandedBits.isSignMask()) | ||
return Op.getOperand(0); | ||
break; | ||
case X86ISD::PCMPGT: | ||
// icmp sgt(0, R) == ashr(R, BitWidth-1). | ||
// iff we only need the sign bit then we can use R directly. | ||
if (DemandedBits.isSignMask() && | ||
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) | ||
return Op.getOperand(1); | ||
case X86ISD::PCMPGT: { | ||
SDLoc DL(Op); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC( | ||
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1), | ||
DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth)) | ||
return R; | ||
break; | ||
} | ||
case X86ISD::CMPP: { | ||
SDLoc DL(Op); | ||
ISD::CondCode CC = X86::getCondForCMPPImm( | ||
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue()); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0), | ||
Op.getOperand(1), | ||
DemandedBits, DemandedElts, | ||
/*AllowNOT*/ false, Depth)) | ||
return R; | ||
break; | ||
} | ||
case X86ISD::BLENDV: { | ||
// BLENDV: Cond (MSB) ? LHS : RHS | ||
SDValue Cond = Op.getOperand(0); | ||
|
@@ -48267,7 +48572,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG, | |
|
||
// We do not split for SSE at all, but we need to split vectors for AVX1 and | ||
// AVX2. | ||
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() && | ||
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() && | ||
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) { | ||
SDValue LoX, HiX; | ||
std::tie(LoX, HiX) = splitVector(X, DAG, DL); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay = Op0Normal.value_or(false);