Skip to content

[X86] Try Folding icmp of v8i32 -> fcmp of v8f32 on AVX #82290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 318 additions & 13 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23391,6 +23391,136 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
}
}

// We get bad codegen for v8i32 compares on avx targets (without avx2) so if
// possible convert to a v8f32 compare.
if (VTOp0 == MVT::v8i32 && Subtarget.hasAVX() && !Subtarget.hasAVX2()) {
std::optional<KnownBits> KnownOps[2];
// Check if an op is known to be in a certain range.
auto OpInRange = [&DAG, Op, &KnownOps](unsigned OpNo, bool CmpLT,
const APInt Bound) {
if (!KnownOps[OpNo].has_value())
KnownOps[OpNo] = DAG.computeKnownBits(Op.getOperand(OpNo));

if (KnownOps[OpNo]->isUnknown())
return false;

std::optional<bool> Res;
if (CmpLT)
Res = KnownBits::ult(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
else
Res = KnownBits::ugt(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
return Res.value_or(false);
};

bool OkayCvt = false;
bool OkayBitcast = false;

const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(MVT::f32);

// For cvt up to 1 << (Significand Precision), (1 << 24 for ieee float)
const APInt MaxConvertableCvt =
APInt::getOneBitSet(32, APFloat::semanticsPrecision(Sem));
// For bitcast up to (and including) first inf representation (0x7f800000 +
// 1 for ieee float)
const APInt MaxConvertableBitcast =
APFloat::getInf(Sem).bitcastToAPInt() + 1;
// For bitcast we also exclude de-norm values. This is absolutely necessary
// for strict semantic correctness, but DAZ (de-norm as zero) will break if
// we don't have this check.
const APInt MinConvertableBitcast =
APFloat::getSmallestNormalized(Sem).bitcastToAPInt() - 1;

assert(
MaxConvertableBitcast.getBitWidth() == 32 &&
MaxConvertableCvt == (1U << 24) &&
MaxConvertableBitcast == 0x7f800001 &&
MinConvertableBitcast.isNonNegative() &&
MaxConvertableBitcast.sgt(MinConvertableBitcast) &&
"This transform has only been verified to IEEE Single Precision Float");

// For bitcast we need both lhs/op1 u< MaxConvertableBitcast
// NB: It might be worth it to enable to bitcast version for unsigned avx2
// comparisons as they typically require multiple instructions to lower
// (they don't fit `vpcmpeq`/`vpcmpgt` well).
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableBitcast) &&
OpInRange(1, /*CmpLT*/ false, MinConvertableBitcast) &&
OpInRange(0, /*CmpLT*/ true, MaxConvertableBitcast) &&
OpInRange(0, /*CmpLT*/ false, MinConvertableBitcast)) {
OkayBitcast = true;
}
// We want to convert icmp -> fcmp using `sitofp` iff one of the converts
// will be constant folded.
else if ((DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op1)) ||
DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op0)))) {
if (isUnsignedIntSetCC(Cond)) {
// For cvt + unsigned compare we need both lhs/rhs >= 0 and either lhs
// or rhs < MaxConvertableCvt

if (OpInRange(1, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
OpInRange(0, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldsteinn The comment says both lhs/rhs >= 0 but the comparisons are with MIN_INT ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its unsigned compare. So foo u< MIN_INT -> foo s>= 0.

(OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt)))
OkayCvt = true;
} else {
// For cvt + signed compare we need abs(lhs) or abs(rhs) <
// MaxConvertableCvt
if (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
OpInRange(1, /*CmpLT*/ false, -MaxConvertableCvt) ||
OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt) ||
OpInRange(0, /*CmpLT*/ false, -MaxConvertableCvt))
OkayCvt = true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need signed comparisons here - OpInRange just does unsigned comparisons? Ideally we'd use NumSignBits checks as an alternative to also check for being in signed bounds.

}
}
// TODO: If we can't prove any of the ranges, we could unconditionally lower
// `(icmp eq lhs, rhs)` as `(icmp eq (int_to_fp (xor lhs, rhs)), zero)`
if (OkayBitcast || OkayCvt) {
switch (Cond) {
default:
llvm_unreachable("Unexpected SETCC condition");
// Get the new FP condition. Note for the unsigned conditions we have
// verified its okay to convert to the signed version.
case ISD::SETULT:
case ISD::SETLT:
Cond = ISD::SETOLT;
break;
case ISD::SETUGT:
case ISD::SETGT:
Cond = ISD::SETOGT;
break;
case ISD::SETULE:
case ISD::SETLE:
Cond = ISD::SETOLE;
break;
case ISD::SETUGE:
case ISD::SETGE:
Cond = ISD::SETOGE;
break;
case ISD::SETEQ:
Cond = ISD::SETOEQ;
break;
case ISD::SETNE:
Cond = ISD::SETONE;
break;
}

MVT FpVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
SDNodeFlags Flags;
Flags.setNoNaNs(true);
Flags.setNoInfs(true);
Flags.setNoSignedZeros(true);
if (OkayBitcast) {
Op0 = DAG.getBitcast(FpVT, Op0);
Op1 = DAG.getBitcast(FpVT, Op1);
} else {
Op0 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op0);
Op1 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op1);
}
Op0->setFlags(Flags);
Op1->setFlags(Flags);
return DAG.getSetCC(dl, VT, Op0, Op1, Cond);
}
}

// Break 256-bit integer vector compare into smaller ones.
if (VT.is256BitVector() && !Subtarget.hasInt256())
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
Expand Down Expand Up @@ -41216,6 +41346,156 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// Simplify a decomposed (sext (setcc)). Assumes prior check that
// bitwidth(sext)==bitwidth(setcc operands).
static SDValue simplifySExtOfDecomposedSetCCImpl(
SelectionDAG &DAG, const SDLoc &DL, ISD::CondCode CC, SDValue Op0,
SDValue Op1, const APInt &OriginalDemandedBits,
const APInt &OriginalDemandedElts, bool AllowNOT, unsigned Depth) {
// Possible TODO: We could handle any power of two demanded bit + unsigned
// comparison. There are no x86 specific comparisons that are unsigned so its
// unneeded.
if (!OriginalDemandedBits.isSignMask())
return SDValue();

EVT OpVT = Op0.getValueType();
// We need need nofpclass(nan inf nzero) to handle floats.
auto hasOkayFPFlags = [](SDValue Op) {
return Op.getOpcode() == ISD::SINT_TO_FP ||
Op.getOpcode() == ISD::UINT_TO_FP ||
(Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
Op->getFlags().hasNoSignedZeros());
};

if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
return SDValue();

auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
if (OpVT.isFloatingPoint()) {
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
}
return V0.eq(V1);
};

// Assume we canonicalized constants to Op1. That isn't always true but we
// call this function twice with inverted CC/Operands so its fine either way.
APInt Op1C;
unsigned ValWidth = OriginalDemandedBits.getBitWidth();
if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
Op1C = APInt::getZero(ValWidth);
} else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
Op1C = APInt::getAllOnes(ValWidth);
} else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
Op1C = C->getValueAPF().bitcastToAPInt();
} else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
Op1C = C->getAPIntValue();
} else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
// isConstantSplatVector sets `Op1C`.
} else {
return SDValue();
}

bool Not = false;
bool Okay = false;
assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
"Invalid constant operand");

switch (CC) {
case ISD::SETGE:
case ISD::SETOGE:
Not = true;
[[fallthrough]];
case ISD::SETLT:
case ISD::SETOLT:
// signbit(sext(x s< 0)) == signbit(x)
// signbit(sext(x s>= 0)) == signbit(~x)
Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
// For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break
// this fold.
// NB: We only need de-norm check here, for the rest of the constants any
// relationship with a de-norm value and zero will be identical.
if (Okay && OpVT.isFloatingPoint()) {
// Values from integers are always normal.
if (Op0.getOpcode() == ISD::SINT_TO_FP ||
Op0.getOpcode() == ISD::UINT_TO_FP)
break;

// See if we can prove normal with known bits.
KnownBits Op0Known =
DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth);
// Negative/positive doesn't matter.
Op0Known.One.clearSignBit();
Op0Known.Zero.clearSignBit();

// Get min normal value.
const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
KnownBits MinNormal = KnownBits::makeConstant(
APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
// Are we above de-norm range?
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
Okay = Op0Normal.value_or(false);
}
break;
case ISD::SETGT:
case ISD::SETOGT:
Not = true;
[[fallthrough]];
case ISD::SETLE:
case ISD::SETOLE:
// signbit(sext(x s<= -1)) == signbit(x)
// signbit(sext(x s> -1)) == signbit(~x)
Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
break;
case ISD::SETULT:
Not = true;
[[fallthrough]];
case ISD::SETUGE:
// signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
// signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits);
break;
case ISD::SETULE:
Not = true;
[[fallthrough]];
case ISD::SETUGT:
// signbit(sext(x u> SIGNED_MAX)) == signbit(x)
// signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1);
break;
default:
break;
}

Okay &= Not ? AllowNOT : true;
if (!Okay)
return SDValue();

if (!Not)
return Op0;

if (!OpVT.isFloatingPoint())
return DAG.getNOT(DL, Op0, OpVT);

// Possible TODO: We could use `fneg` to do not.
return SDValue();
}

static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, const SDLoc &DL,
ISD::CondCode CC, SDValue Op0,
SDValue Op1,
const APInt &OriginalDemandedBits,
const APInt &OriginalDemandedElts,
bool AllowNOT, unsigned Depth) {
if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts,
AllowNOT, Depth))
return R;
return simplifySExtOfDecomposedSetCCImpl(
DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
OriginalDemandedElts, AllowNOT, Depth);
}

// Simplify variable target shuffle masks based on the demanded elements.
// TODO: Handle DemandedBits in mask indices as well?
bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
Expand Down Expand Up @@ -42395,13 +42675,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
}
break;
}
case X86ISD::PCMPGT:
// icmp sgt(0, R) == ashr(R, BitWidth-1).
// iff we only need the sign bit then we can use R directly.
if (OriginalDemandedBits.isSignMask() &&
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
return TLO.CombineTo(Op, Op.getOperand(1));
case X86ISD::PCMPGT: {
SDLoc DL(Op);
if (SDValue R = simplifySExtOfDecomposedSetCC(
TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
OriginalDemandedBits, OriginalDemandedElts,
/*AllowNOT*/ true, Depth))
return TLO.CombineTo(Op, R);
break;
}
case X86ISD::CMPP: {
SDLoc DL(Op);
ISD::CondCode CC = X86::getCondForCMPPImm(
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
if (SDValue R = simplifySExtOfDecomposedSetCC(
TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
OriginalDemandedBits, OriginalDemandedElts,
!(TLO.LegalOperations() && TLO.LegalTypes()), Depth))
return TLO.CombineTo(Op, R);
break;
}
case X86ISD::MOVMSK: {
SDValue Src = Op.getOperand(0);
MVT SrcVT = Src.getSimpleValueType();
Expand Down Expand Up @@ -42585,13 +42878,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
if (DemandedBits.isSignMask())
return Op.getOperand(0);
break;
case X86ISD::PCMPGT:
// icmp sgt(0, R) == ashr(R, BitWidth-1).
// iff we only need the sign bit then we can use R directly.
if (DemandedBits.isSignMask() &&
ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
return Op.getOperand(1);
case X86ISD::PCMPGT: {
SDLoc DL(Op);
if (SDValue R = simplifySExtOfDecomposedSetCC(
DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth))
return R;
break;
}
case X86ISD::CMPP: {
SDLoc DL(Op);
ISD::CondCode CC = X86::getCondForCMPPImm(
cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0),
Op.getOperand(1),
DemandedBits, DemandedElts,
/*AllowNOT*/ false, Depth))
return R;
break;
}
case X86ISD::BLENDV: {
// BLENDV: Cond (MSB) ? LHS : RHS
SDValue Cond = Op.getOperand(0);
Expand Down Expand Up @@ -48267,7 +48572,7 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,

// We do not split for SSE at all, but we need to split vectors for AVX1 and
// AVX2.
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) {
SDValue LoX, HiX;
std::tie(LoX, HiX) = splitVector(X, DAG, DL);
Expand Down
Loading