-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[X86] Try Folding icmp of v8i32 -> fcmp of v8f32 on AVX #82290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23391,6 +23391,136 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, | |
} | ||
} | ||
|
||
// We get bad codegen for v8i32 compares on avx targets (without avx2) so if | ||
// possible convert to a v8f32 compare. | ||
if (VTOp0 == MVT::v8i32 && Subtarget.hasAVX() && !Subtarget.hasAVX2()) { | ||
std::optional<KnownBits> KnownOps[2]; | ||
// Check if an op is known to be in a certain range. | ||
auto OpInRange = [&DAG, Op, &KnownOps](unsigned OpNo, bool CmpLT, | ||
const APInt Bound) { | ||
if (!KnownOps[OpNo].has_value()) | ||
KnownOps[OpNo] = DAG.computeKnownBits(Op.getOperand(OpNo)); | ||
|
||
if (KnownOps[OpNo]->isUnknown()) | ||
return false; | ||
|
||
std::optional<bool> Res; | ||
if (CmpLT) | ||
Res = KnownBits::ult(*KnownOps[OpNo], KnownBits::makeConstant(Bound)); | ||
else | ||
Res = KnownBits::ugt(*KnownOps[OpNo], KnownBits::makeConstant(Bound)); | ||
return Res.value_or(false); | ||
}; | ||
|
||
bool OkayCvt = false; | ||
bool OkayBitcast = false; | ||
|
||
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(MVT::f32); | ||
|
||
// For cvt up to 1 << (Significand Precision), (1 << 24 for ieee float) | ||
const APInt MaxConvertableCvt = | ||
APInt::getOneBitSet(32, APFloat::semanticsPrecision(Sem)); | ||
// For bitcast up to (and including) first inf representation (0x7f800000 + | ||
// 1 for ieee float) | ||
const APInt MaxConvertableBitcast = | ||
APFloat::getInf(Sem).bitcastToAPInt() + 1; | ||
// For bitcast we also exclude de-norm values. This is not absolutely necessary | ||
// for strict semantic correctness, but DAZ (de-norm as zero) will break if | ||
// we don't have this check. | ||
const APInt MinConvertableBitcast = | ||
APFloat::getSmallestNormalized(Sem).bitcastToAPInt() - 1; | ||
|
||
assert( | ||
MaxConvertableBitcast.getBitWidth() == 32 && | ||
MaxConvertableCvt == (1U << 24) && | ||
MaxConvertableBitcast == 0x7f800001 && | ||
MinConvertableBitcast.isNonNegative() && | ||
MaxConvertableBitcast.sgt(MinConvertableBitcast) && | ||
"This transform has only been verified to IEEE Single Precision Float"); | ||
|
||
// For bitcast we need both lhs/rhs u< MaxConvertableBitcast | ||
// NB: It might be worth it to enable the bitcast version for unsigned avx2 | ||
// comparisons as they typically require multiple instructions to lower | ||
// (they don't fit `vpcmpeq`/`vpcmpgt` well). | ||
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableBitcast) && | ||
OpInRange(1, /*CmpLT*/ false, MinConvertableBitcast) && | ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableBitcast) && | ||
OpInRange(0, /*CmpLT*/ false, MinConvertableBitcast)) { | ||
OkayBitcast = true; | ||
} | ||
// We want to convert icmp -> fcmp using `sitofp` iff one of the converts | ||
// will be constant folded. | ||
else if ((DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op1)) || | ||
DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op0)))) { | ||
if (isUnsignedIntSetCC(Cond)) { | ||
// For cvt + unsigned compare we need both lhs/rhs >= 0 and either lhs | ||
// or rhs < MaxConvertableCvt | ||
|
||
if (OpInRange(1, /*CmpLT*/ true, APInt::getSignedMinValue(32)) && | ||
OpInRange(0, /*CmpLT*/ true, APInt::getSignedMinValue(32)) && | ||
(OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) || | ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt))) | ||
OkayCvt = true; | ||
} else { | ||
// For cvt + signed compare we need abs(lhs) or abs(rhs) < | ||
// MaxConvertableCvt | ||
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) || | ||
OpInRange(1, /*CmpLT*/ false, -MaxConvertableCvt) || | ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt) || | ||
OpInRange(0, /*CmpLT*/ false, -MaxConvertableCvt)) | ||
OkayCvt = true; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't we need signed comparisons here? OpInRange only does unsigned comparisons. Ideally we'd use NumSignBits checks as an alternative, to also check for being within signed bounds. |
||
} | ||
} | ||
// TODO: If we can't prove any of the ranges, we could unconditionally lower | ||
// `(icmp eq lhs, rhs)` as `(icmp eq (int_to_fp (xor lhs, rhs)), zero)` | ||
if (OkayBitcast || OkayCvt) { | ||
switch (Cond) { | ||
default: | ||
llvm_unreachable("Unexpected SETCC condition"); | ||
// Get the new FP condition. Note for the unsigned conditions we have | ||
// verified its okay to convert to the signed version. | ||
case ISD::SETULT: | ||
case ISD::SETLT: | ||
Cond = ISD::SETOLT; | ||
break; | ||
case ISD::SETUGT: | ||
case ISD::SETGT: | ||
Cond = ISD::SETOGT; | ||
break; | ||
case ISD::SETULE: | ||
case ISD::SETLE: | ||
Cond = ISD::SETOLE; | ||
break; | ||
case ISD::SETUGE: | ||
case ISD::SETGE: | ||
Cond = ISD::SETOGE; | ||
break; | ||
case ISD::SETEQ: | ||
Cond = ISD::SETOEQ; | ||
break; | ||
case ISD::SETNE: | ||
Cond = ISD::SETONE; | ||
break; | ||
} | ||
|
||
MVT FpVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); | ||
SDNodeFlags Flags; | ||
Flags.setNoNaNs(true); | ||
Flags.setNoInfs(true); | ||
Flags.setNoSignedZeros(true); | ||
if (OkayBitcast) { | ||
Op0 = DAG.getBitcast(FpVT, Op0); | ||
Op1 = DAG.getBitcast(FpVT, Op1); | ||
} else { | ||
Op0 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op0); | ||
Op1 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op1); | ||
} | ||
Op0->setFlags(Flags); | ||
Op1->setFlags(Flags); | ||
return DAG.getSetCC(dl, VT, Op0, Op1, Cond); | ||
} | ||
} | ||
|
||
// Break 256-bit integer vector compare into smaller ones. | ||
if (VT.is256BitVector() && !Subtarget.hasInt256()) | ||
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl); | ||
|
@@ -41216,6 +41346,156 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, | |
return SDValue(); | ||
} | ||
|
||
// Simplify a decomposed (sext (setcc)). Assumes prior check that | ||
// bitwidth(sext)==bitwidth(setcc operands). | ||
// | ||
// Only the case where exactly the sign bit is demanded is handled, and Op1 | ||
// must be a constant (scalar or splat); anything else yields an empty SDValue. | ||
// On success the result is either Op0 itself (the sign bit of the compare | ||
// result equals the sign bit of Op0) or, for integer operand types when | ||
// AllowNOT is set, NOT(Op0) (the sign bit of the compare result is the | ||
// inverse of Op0's sign bit). | ||
// | ||
// CC is the condition code of the decomposed compare of Op0 against Op1. | ||
// OriginalDemandedElts and Depth are forwarded to computeKnownBits for the | ||
// floating point normal-value check. | ||
static SDValue simplifySExtOfDecomposedSetCCImpl( | ||
SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0, | ||
SDValue Op1, const APInt &OriginalDemandedBits, | ||
const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) { | ||
// Possible TODO: We could handle any power of two demanded bit + unsigned | ||
// comparison. There are no x86 specific comparisons that are unsigned so it's | ||
// unneeded. | ||
if (!OriginalDemandedBits.isSignMask()) | ||
return SDValue(); | ||
|
||
EVT OpVT = Op0.getValueType(); | ||
// We need nofpclass(nan inf nzero) to handle floats. | ||
// Int-to-fp conversions are accepted without inspecting their flags. | ||
auto hasOkayFPFlags = [](SDValue Op) { | ||
return Op.getOpcode() == ISD::SINT_TO_FP || | ||
Op.getOpcode() == ISD::UINT_TO_FP || | ||
(Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() && | ||
Op->getFlags().hasNoSignedZeros()); | ||
}; | ||
|
||
// Only Op0 is checked here; Op1 is required to be a constant below. | ||
if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0)) | ||
return SDValue(); | ||
|
||
// Compare the constant operand V0 against the expected pattern V1. For | ||
// floating point operand types V1 is first routed through an APFloat of the | ||
// operand semantics and bitcast back to an APInt. | ||
auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool { | ||
if (OpVT.isFloatingPoint()) { | ||
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT); | ||
return V0.eq(APFloat(Sem, V1).bitcastToAPInt()); | ||
} | ||
return V0.eq(V1); | ||
}; | ||
|
||
// Assume we canonicalized constants to Op1. That isn't always true but we | ||
// call this function twice with inverted CC/Operands so it's fine either way. | ||
APInt Op1C; | ||
unsigned ValWidth = OriginalDemandedBits.getBitWidth(); | ||
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) { | ||
Op1C = APInt::getZero(ValWidth); | ||
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) { | ||
Op1C = APInt::getAllOnes(ValWidth); | ||
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) { | ||
Op1C = C->getValueAPF().bitcastToAPInt(); | ||
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) { | ||
Op1C = C->getAPIntValue(); | ||
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) { | ||
// isConstantSplatVector sets `Op1C`. | ||
} else { | ||
return SDValue(); | ||
} | ||
|
||
// Not: the compare's sign bit is the inverse of Op0's sign bit. | ||
// Okay: the (CC, constant) pair matched one of the patterns below. | ||
bool Not = false; | ||
bool Okay = false; | ||
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() && | ||
"Invalid constant operand"); | ||
|
||
switch (CC) { | ||
case ISD::SETGE: | ||
case ISD::SETOGE: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETLT: | ||
case ISD::SETOLT: | ||
// signbit(sext(x s< 0)) == signbit(x) | ||
// signbit(sext(x s>= 0)) == signbit(~x) | ||
Okay = ValsEq(Op1C, APInt::getZero(ValWidth)); | ||
// For float ops we need to ensure Op0 is not de-norm. Otherwise DAZ can | ||
// break this fold. | ||
// NB: We only need de-norm check here, for the rest of the constants any | ||
// relationship with a de-norm value and zero will be identical. | ||
if (Okay && OpVT.isFloatingPoint()) { | ||
// Values from integers are always normal. | ||
if (Op0.getOpcode() == ISD::SINT_TO_FP || | ||
Op0.getOpcode() == ISD::UINT_TO_FP) | ||
break; | ||
|
||
// See if we can prove normal with known bits. | ||
KnownBits Op0Known = | ||
DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth); | ||
// Negative/positive doesn't matter. | ||
Op0Known.One.clearSignBit(); | ||
Op0Known.Zero.clearSignBit(); | ||
|
||
// Get min normal value. | ||
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT); | ||
KnownBits MinNormal = KnownBits::makeConstant( | ||
APFloat::getSmallestNormalized(Sem).bitcastToAPInt()); | ||
// Are we above de-norm range? | ||
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal); | ||
Okay = Op0Normal.value_or(false); | ||
} | ||
break; | ||
case ISD::SETGT: | ||
case ISD::SETOGT: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETLE: | ||
case ISD::SETOLE: | ||
// signbit(sext(x s<= -1)) == signbit(x) | ||
// signbit(sext(x s> -1)) == signbit(~x) | ||
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth)); | ||
break; | ||
// The unsigned predicates below are only accepted for integer operand types. | ||
case ISD::SETULT: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETUGE: | ||
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x) | ||
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x) | ||
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits); | ||
break; | ||
case ISD::SETULE: | ||
Not = true; | ||
[[fallthrough]]; | ||
case ISD::SETUGT: | ||
// signbit(sext(x u> SIGNED_MAX)) == signbit(x) | ||
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x) | ||
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1); | ||
break; | ||
default: | ||
break; | ||
} | ||
|
||
// A fold that needs an inverted Op0 is only permitted when the caller allows | ||
// creating a NOT. | ||
Okay &= Not ? AllowNOT : true; | ||
if (!Okay) | ||
return SDValue(); | ||
|
||
if (!Not) | ||
return Op0; | ||
|
||
if (!OpVT.isFloatingPoint()) | ||
return DAG.getNOT(DL, Op0, OpVT); | ||
|
||
// Possible TODO: We could use `fneg` to do not. | ||
return SDValue(); | ||
} | ||
|
||
// Try the sign-bit simplification of a decomposed (sext (setcc)) with the | ||
// operands as given. If that produces nothing, retry with Op0/Op1 exchanged | ||
// and the condition code adjusted through ISD::getSetCCSwappedOperands. | ||
// Returns an empty SDValue when neither orientation simplifies. | ||
static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, const SDLoc &DL, | ||
ISD::CondCode CC, SDValue Op0, | ||
SDValue Op1, | ||
const APInt &OriginalDemandedBits, | ||
const APInt &OriginalDemandedElts, | ||
bool AllowNOT, unsigned Depth) { | ||
// First attempt: operands in their original order. | ||
SDValue Simplified = simplifySExtOfDecomposedSetCCImpl( | ||
DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts, | ||
AllowNOT, Depth); | ||
if (!Simplified) { | ||
// Second attempt: commuted operands with the matching swapped predicate. | ||
const ISD::CondCode SwappedCC = ISD::getSetCCSwappedOperands(CC); | ||
Simplified = simplifySExtOfDecomposedSetCCImpl( | ||
DAG, DL, SwappedCC, Op1, Op0, OriginalDemandedBits, | ||
OriginalDemandedElts, AllowNOT, Depth); | ||
} | ||
return Simplified; | ||
} | ||
|
||
// Simplify variable target shuffle masks based on the demanded elements. | ||
// TODO: Handle DemandedBits in mask indices as well? | ||
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle( | ||
|
@@ -42395,13 +42675,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( | |
} | ||
break; | ||
} | ||
case X86ISD::PCMPGT: | ||
// icmp sgt(0, R) == ashr(R, BitWidth-1). | ||
// iff we only need the sign bit then we can use R directly. | ||
if (OriginalDemandedBits.isSignMask() && | ||
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) | ||
return TLO.CombineTo(Op, Op.getOperand(1)); | ||
case X86ISD::PCMPGT: { | ||
SDLoc DL(Op); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC( | ||
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1), | ||
OriginalDemandedBits, OriginalDemandedElts, | ||
/*AllowNOT*/ true, Depth)) | ||
return TLO.CombineTo(Op, R); | ||
break; | ||
} | ||
case X86ISD::CMPP: { | ||
SDLoc DL(Op); | ||
ISD::CondCode CC = X86::getCondForCMPPImm( | ||
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue()); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC( | ||
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1), | ||
OriginalDemandedBits, OriginalDemandedElts, | ||
!(TLO.LegalOperations() && TLO.LegalTypes()), Depth)) | ||
return TLO.CombineTo(Op, R); | ||
break; | ||
} | ||
case X86ISD::MOVMSK: { | ||
SDValue Src = Op.getOperand(0); | ||
MVT SrcVT = Src.getSimpleValueType(); | ||
|
@@ -42585,13 +42878,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode( | |
if (DemandedBits.isSignMask()) | ||
return Op.getOperand(0); | ||
break; | ||
case X86ISD::PCMPGT: | ||
// icmp sgt(0, R) == ashr(R, BitWidth-1). | ||
// iff we only need the sign bit then we can use R directly. | ||
if (DemandedBits.isSignMask() && | ||
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) | ||
return Op.getOperand(1); | ||
case X86ISD::PCMPGT: { | ||
SDLoc DL(Op); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC( | ||
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1), | ||
DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth)) | ||
return R; | ||
break; | ||
} | ||
case X86ISD::CMPP: { | ||
SDLoc DL(Op); | ||
ISD::CondCode CC = X86::getCondForCMPPImm( | ||
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue()); | ||
if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0), | ||
Op.getOperand(1), | ||
DemandedBits, DemandedElts, | ||
/*AllowNOT*/ false, Depth)) | ||
return R; | ||
break; | ||
} | ||
case X86ISD::BLENDV: { | ||
// BLENDV: Cond (MSB) ? LHS : RHS | ||
SDValue Cond = Op.getOperand(0); | ||
|
@@ -48267,7 +48572,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG, | |
|
||
// We do not split for SSE at all, but we need to split vectors for AVX1 and | ||
// AVX2. | ||
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() && | ||
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() && | ||
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) { | ||
SDValue LoX, HiX; | ||
std::tie(LoX, HiX) = splitVector(X, DAG, DL); | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@goldsteinn The comment says both lhs/rhs >= 0 but the comparisons are with MIN_INT ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's an unsigned compare, so
foo u< MIN_INT
-> foo s>= 0
.