Skip to content

Commit dc3faf0

Browse files
authored
ValueTracking: Identify implied fp classes by general fcmp (#66505)
Previously we could recognize exact class tests performed by an fcmp with special values (0s, infs and smallest normal). Expand this to recognize the implied classes by a compare with a general constant. e.g. fcmp ogt x, 1 implies positive and non-0. The API should be better merged with fcmpToClassTest but that made the diff way bigger, will try to do that in a future patch.
1 parent ca39c83 commit dc3faf0

File tree

4 files changed

+492
-422
lines changed

4 files changed

+492
-422
lines changed

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,27 @@ std::pair<Value *, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
240240
const APFloat *ConstRHS,
241241
bool LookThroughSrc = true);
242242

243+
/// Compute the possible floating-point classes that \p LHS could be based on an
244+
/// fcmp returning true. Returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
245+
///
246+
/// If the compare returns an exact class test, ClassesIfTrue == ~ClassesIfFalse
247+
///
248+
/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
249+
/// only succeed for a test of x > 0 implies positive, but not x > 1).
250+
///
251+
/// If \p LookThroughSrc is true, consider the input value when computing the
252+
/// mask. This may look through sign bit operations.
253+
///
254+
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
255+
/// element will always be LHS.
256+
///
257+
std::tuple<Value *, FPClassTest, FPClassTest>
258+
fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
259+
const APFloat *ConstRHS, bool LookThroughSrc = true);
260+
std::tuple<Value *, FPClassTest, FPClassTest>
261+
fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
262+
Value *RHS, bool LookThroughSrc = true);
263+
243264
struct KnownFPClass {
244265
/// Floating-point classes the value could be one of.
245266
FPClassTest KnownFPClasses = fcAllFlags;

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 158 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4248,6 +4248,139 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
42484248
return {Src, Mask};
42494249
}
42504250

4251+
std::tuple<Value *, FPClassTest, FPClassTest>
4252+
llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
4253+
const APFloat *ConstRHS, bool LookThroughSrc) {
4254+
auto [Val, ClassMask] =
4255+
fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
4256+
if (Val)
4257+
return {Val, ClassMask, ~ClassMask};
4258+
4259+
FPClassTest RHSClass = ConstRHS->classify();
4260+
assert((RHSClass == fcPosNormal || RHSClass == fcNegNormal ||
4261+
RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal) &&
4262+
"should have been recognized as an exact class test");
4263+
4264+
const bool IsNegativeRHS = (RHSClass & fcNegative) == RHSClass;
4265+
const bool IsPositiveRHS = (RHSClass & fcPositive) == RHSClass;
4266+
4267+
assert(IsNegativeRHS == ConstRHS->isNegative());
4268+
assert(IsPositiveRHS == !ConstRHS->isNegative());
4269+
4270+
Value *Src = LHS;
4271+
const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
4272+
4273+
if (IsFabs)
4274+
RHSClass = llvm::inverse_fabs(RHSClass);
4275+
4276+
if (Pred == FCmpInst::FCMP_OEQ)
4277+
return {Src, RHSClass, fcAllFlags};
4278+
4279+
if (Pred == FCmpInst::FCMP_UEQ) {
4280+
FPClassTest Class = RHSClass | fcNan;
4281+
return {Src, Class, ~fcNan};
4282+
}
4283+
4284+
if (Pred == FCmpInst::FCMP_ONE)
4285+
return {Src, ~fcNan, RHSClass};
4286+
4287+
if (Pred == FCmpInst::FCMP_UNE)
4288+
return {Src, fcAllFlags, RHSClass};
4289+
4290+
if (IsNegativeRHS) {
4291+
// TODO: Handle fneg(fabs)
4292+
if (IsFabs) {
4293+
// fabs(x) o> -k -> fcmp ord x, x
4294+
// fabs(x) u> -k -> true
4295+
// fabs(x) o< -k -> false
4296+
// fabs(x) u< -k -> fcmp uno x, x
4297+
switch (Pred) {
4298+
case FCmpInst::FCMP_OGT:
4299+
case FCmpInst::FCMP_OGE:
4300+
return {Src, ~fcNan, fcNan};
4301+
case FCmpInst::FCMP_UGT:
4302+
case FCmpInst::FCMP_UGE:
4303+
return {Src, fcAllFlags, fcNone};
4304+
case FCmpInst::FCMP_OLT:
4305+
case FCmpInst::FCMP_OLE:
4306+
return {Src, fcNone, fcAllFlags};
4307+
case FCmpInst::FCMP_ULT:
4308+
case FCmpInst::FCMP_ULE:
4309+
return {Src, fcNan, ~fcNan};
4310+
default:
4311+
break;
4312+
}
4313+
4314+
return {nullptr, fcAllFlags, fcAllFlags};
4315+
}
4316+
4317+
FPClassTest ClassesLE = fcNegInf | fcNegNormal;
4318+
FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;
4319+
4320+
if (ConstRHS->isDenormal())
4321+
ClassesLE |= fcNegSubnormal;
4322+
else
4323+
ClassesGE |= fcNegNormal;
4324+
4325+
switch (Pred) {
4326+
case FCmpInst::FCMP_OGT:
4327+
case FCmpInst::FCMP_OGE:
4328+
return {Src, ClassesGE, ~ClassesGE | RHSClass};
4329+
case FCmpInst::FCMP_UGT:
4330+
case FCmpInst::FCMP_UGE:
4331+
return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
4332+
case FCmpInst::FCMP_OLT:
4333+
case FCmpInst::FCMP_OLE:
4334+
return {Src, ClassesLE, ~ClassesLE | RHSClass};
4335+
case FCmpInst::FCMP_ULT:
4336+
case FCmpInst::FCMP_ULE:
4337+
return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
4338+
default:
4339+
break;
4340+
}
4341+
} else if (IsPositiveRHS) {
4342+
FPClassTest ClassesGE = fcPosNormal | fcPosInf;
4343+
FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosNormal;
4344+
if (ConstRHS->isDenormal())
4345+
ClassesGE |= fcPosNormal;
4346+
else
4347+
ClassesLE |= fcPosSubnormal;
4348+
4349+
if (IsFabs) {
4350+
ClassesGE = llvm::inverse_fabs(ClassesGE);
4351+
ClassesLE = llvm::inverse_fabs(ClassesLE);
4352+
}
4353+
4354+
switch (Pred) {
4355+
case FCmpInst::FCMP_OGT:
4356+
case FCmpInst::FCMP_OGE:
4357+
return {Src, ClassesGE, ~ClassesGE | RHSClass};
4358+
case FCmpInst::FCMP_UGT:
4359+
case FCmpInst::FCMP_UGE:
4360+
return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
4361+
case FCmpInst::FCMP_OLT:
4362+
case FCmpInst::FCMP_OLE:
4363+
return {Src, ClassesLE, ~ClassesLE | RHSClass};
4364+
case FCmpInst::FCMP_ULT:
4365+
case FCmpInst::FCMP_ULE:
4366+
return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
4367+
default:
4368+
break;
4369+
}
4370+
}
4371+
4372+
return {nullptr, fcAllFlags, fcAllFlags};
4373+
}
4374+
4375+
std::tuple<Value *, FPClassTest, FPClassTest>
4376+
llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
4377+
Value *RHS, bool LookThroughSrc) {
4378+
const APFloat *ConstRHS;
4379+
if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
4380+
return {nullptr, fcAllFlags, fcNone};
4381+
return fcmpImpliesClass(Pred, F, LHS, ConstRHS, LookThroughSrc);
4382+
}
4383+
42514384
static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
42524385
const SimplifyQuery &Q) {
42534386
FPClassTest KnownFromAssume = fcAllFlags;
@@ -4272,18 +4405,21 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
42724405
Value *LHS, *RHS;
42734406
uint64_t ClassVal = 0;
42744407
if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
4275-
auto [TestedValue, TestedMask] =
4276-
fcmpToClassTest(Pred, *F, LHS, RHS, true);
4277-
// First see if we can fold in fabs/fneg into the test.
4278-
if (TestedValue == V)
4279-
KnownFromAssume &= TestedMask;
4280-
else {
4281-
// Try again without the lookthrough if we found a different source
4282-
// value.
4283-
auto [TestedValue, TestedMask] =
4284-
fcmpToClassTest(Pred, *F, LHS, RHS, false);
4285-
if (TestedValue == V)
4286-
KnownFromAssume &= TestedMask;
4408+
const APFloat *CRHS;
4409+
if (match(RHS, m_APFloat(CRHS))) {
4410+
// First see if we can fold in fabs/fneg into the test.
4411+
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
4412+
fcmpImpliesClass(Pred, *F, LHS, CRHS, true);
4413+
if (CmpVal == V)
4414+
KnownFromAssume &= MaskIfTrue;
4415+
else {
4416+
// Try again without the lookthrough if we found a different source
4417+
// value.
4418+
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
4419+
fcmpImpliesClass(Pred, *F, LHS, CRHS, false);
4420+
if (CmpVal == V)
4421+
KnownFromAssume &= MaskIfTrue;
4422+
}
42874423
}
42884424
} else if (match(I->getArgOperand(0),
42894425
m_Intrinsic<Intrinsic::is_fpclass>(
@@ -4431,7 +4567,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
44314567
FPClassTest FilterRHS = fcAllFlags;
44324568

44334569
Value *TestedValue = nullptr;
4434-
FPClassTest TestedMask = fcNone;
4570+
FPClassTest MaskIfTrue = fcAllFlags;
4571+
FPClassTest MaskIfFalse = fcAllFlags;
44354572
uint64_t ClassVal = 0;
44364573
const Function *F = cast<Instruction>(Op)->getFunction();
44374574
CmpInst::Predicate Pred;
@@ -4443,20 +4580,22 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
44434580
// TODO: In some degenerate cases we can infer something if we try again
44444581
// without looking through sign operations.
44454582
bool LookThroughFAbsFNeg = CmpLHS != LHS && CmpLHS != RHS;
4446-
std::tie(TestedValue, TestedMask) =
4447-
fcmpToClassTest(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
4583+
std::tie(TestedValue, MaskIfTrue, MaskIfFalse) =
4584+
fcmpImpliesClass(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
44484585
} else if (match(Cond,
44494586
m_Intrinsic<Intrinsic::is_fpclass>(
44504587
m_Value(TestedValue), m_ConstantInt(ClassVal)))) {
4451-
TestedMask = static_cast<FPClassTest>(ClassVal);
4588+
FPClassTest TestedMask = static_cast<FPClassTest>(ClassVal);
4589+
MaskIfTrue = TestedMask;
4590+
MaskIfFalse = ~TestedMask;
44524591
}
44534592

44544593
if (TestedValue == LHS) {
44554594
// match !isnan(x) ? x : y
4456-
FilterLHS = TestedMask;
4457-
} else if (TestedValue == RHS) {
4595+
FilterLHS = MaskIfTrue;
4596+
} else if (TestedValue == RHS) { // && IsExactClass
44584597
// match !isnan(x) ? y : x
4459-
FilterRHS = ~TestedMask;
4598+
FilterRHS = MaskIfFalse;
44604599
}
44614600

44624601
KnownFPClass Known2;

0 commit comments

Comments
 (0)