Skip to content

ValueTracking: Identify implied fp classes by general fcmp #66505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,27 @@ std::pair<Value *, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
const APFloat *ConstRHS,
bool LookThroughSrc = true);

/// Compute the possible floating-point classes that \p LHS could be based on an
/// fcmp returning true. Returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
///
/// If the compare returns an exact class test, ClassesIfTrue == ~ClassesIfFalse
///
/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
/// only succeed for a test of x > 0 implies positive, but not x > 1).
///
/// If \p LookThroughSrc is true, consider the input value when computing the
/// mask. This may look through sign bit operations.
///
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
/// element will always be LHS.
///
std::tuple<Value *, FPClassTest, FPClassTest>
fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
const APFloat *ConstRHS, bool LookThroughSrc = true);
std::tuple<Value *, FPClassTest, FPClassTest>
fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
Value *RHS, bool LookThroughSrc = true);

struct KnownFPClass {
/// Floating-point classes the value could be one of.
FPClassTest KnownFPClasses = fcAllFlags;
Expand Down
177 changes: 158 additions & 19 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4245,6 +4245,139 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
return {Src, Mask};
}

std::tuple<Value *, FPClassTest, FPClassTest>
llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
const APFloat *ConstRHS, bool LookThroughSrc) {
auto [Val, ClassMask] =
fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
if (Val)
return {Val, ClassMask, ~ClassMask};

FPClassTest RHSClass = ConstRHS->classify();
assert((RHSClass == fcPosNormal || RHSClass == fcNegNormal ||
RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal) &&
"should have been recognized as an exact class test");

const bool IsNegativeRHS = (RHSClass & fcNegative) == RHSClass;
const bool IsPositiveRHS = (RHSClass & fcPositive) == RHSClass;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having two flags seems overly general. RHS is always either known positive or known negative.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for a nan

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then the asserts immediately below must be broken since they effectively assert that IsNegativeRHS == !IsPositiveRHS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out nans can't reach here and I wrote this this way for the benefit of a future patch to generalize the RHS handling to non-constants


assert(IsNegativeRHS == ConstRHS->isNegative());
assert(IsPositiveRHS == !ConstRHS->isNegative());

Value *Src = LHS;
const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));

if (IsFabs)
RHSClass = llvm::inverse_fabs(RHSClass);

if (Pred == FCmpInst::FCMP_OEQ)
return {Src, RHSClass, fcAllFlags};

if (Pred == FCmpInst::FCMP_UEQ) {
FPClassTest Class = RHSClass | fcNan;
return {Src, Class, ~fcNan};
}

if (Pred == FCmpInst::FCMP_ONE)
return {Src, ~fcNan, RHSClass};

if (Pred == FCmpInst::FCMP_UNE)
return {Src, fcAllFlags, RHSClass};

if (IsNegativeRHS) {
// TODO: Handle fneg(fabs)
if (IsFabs) {
// fabs(x) o> -k -> fcmp ord x, x
// fabs(x) u> -k -> true
// fabs(x) o< -k -> false
// fabs(x) u< -k -> fcmp uno x, x
switch (Pred) {
case FCmpInst::FCMP_OGT:
case FCmpInst::FCMP_OGE:
return {Src, ~fcNan, fcNan};
case FCmpInst::FCMP_UGT:
case FCmpInst::FCMP_UGE:
return {Src, fcAllFlags, fcNone};
case FCmpInst::FCMP_OLT:
case FCmpInst::FCMP_OLE:
return {Src, fcNone, fcAllFlags};
case FCmpInst::FCMP_ULT:
case FCmpInst::FCMP_ULE:
return {Src, fcNan, ~fcNan};
default:
break;
}

return {nullptr, fcAllFlags, fcAllFlags};
}

FPClassTest ClassesLE = fcNegInf | fcNegNormal;
FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;

if (ConstRHS->isDenormal())
ClassesLE |= fcNegSubnormal;
else
ClassesGE |= fcNegNormal;

switch (Pred) {
case FCmpInst::FCMP_OGT:
case FCmpInst::FCMP_OGE:
return {Src, ClassesGE, ~ClassesGE | RHSClass};
case FCmpInst::FCMP_UGT:
case FCmpInst::FCMP_UGE:
return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
case FCmpInst::FCMP_OLT:
case FCmpInst::FCMP_OLE:
return {Src, ClassesLE, ~ClassesLE | RHSClass};
case FCmpInst::FCMP_ULT:
case FCmpInst::FCMP_ULE:
return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
default:
break;
}
} else if (IsPositiveRHS) {
FPClassTest ClassesGE = fcPosNormal | fcPosInf;
FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosNormal;
if (ConstRHS->isDenormal())
ClassesGE |= fcPosNormal;
else
ClassesLE |= fcPosSubnormal;

if (IsFabs) {
ClassesGE = llvm::inverse_fabs(ClassesGE);
ClassesLE = llvm::inverse_fabs(ClassesLE);
}
Comment on lines +4346 to +4349
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One can expect that IsPositiveRHS should be symmetrical to IsNegativeRHS but here their implementations are different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's different because this is the case that isn't rooted at 0. We need to consider the values between 0 and the absolute value of the constant


switch (Pred) {
case FCmpInst::FCMP_OGT:
case FCmpInst::FCMP_OGE:
return {Src, ClassesGE, ~ClassesGE | RHSClass};
case FCmpInst::FCMP_UGT:
case FCmpInst::FCMP_UGE:
return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
case FCmpInst::FCMP_OLT:
case FCmpInst::FCMP_OLE:
return {Src, ClassesLE, ~ClassesLE | RHSClass};
case FCmpInst::FCMP_ULT:
case FCmpInst::FCMP_ULE:
return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
default:
break;
}
}

return {nullptr, fcAllFlags, fcAllFlags};
}

std::tuple<Value *, FPClassTest, FPClassTest>
llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
Value *RHS, bool LookThroughSrc) {
const APFloat *ConstRHS;
if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
return {nullptr, fcAllFlags, fcNone};
return fcmpImpliesClass(Pred, F, LHS, ConstRHS, LookThroughSrc);
}

static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
const SimplifyQuery &Q) {
FPClassTest KnownFromAssume = fcAllFlags;
Expand All @@ -4269,18 +4402,21 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
Value *LHS, *RHS;
uint64_t ClassVal = 0;
if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
auto [TestedValue, TestedMask] =
fcmpToClassTest(Pred, *F, LHS, RHS, true);
// First see if we can fold in fabs/fneg into the test.
if (TestedValue == V)
KnownFromAssume &= TestedMask;
else {
// Try again without the lookthrough if we found a different source
// value.
auto [TestedValue, TestedMask] =
fcmpToClassTest(Pred, *F, LHS, RHS, false);
if (TestedValue == V)
KnownFromAssume &= TestedMask;
const APFloat *CRHS;
if (match(RHS, m_APFloat(CRHS))) {
// First see if we can fold in fabs/fneg into the test.
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
fcmpImpliesClass(Pred, *F, LHS, CRHS, true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you make just one call to fcmpImpliesClass, passing in LHS != V for the LookThroughSrc argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this in an earlier version of fcmpToClassTest but it missed some cases. Seems to not if I make this change, but I think it will still miss something

if (CmpVal == V)
KnownFromAssume &= MaskIfTrue;
else {
// Try again without the lookthrough if we found a different source
// value.
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
fcmpImpliesClass(Pred, *F, LHS, CRHS, false);
if (CmpVal == V)
KnownFromAssume &= MaskIfTrue;
}
}
} else if (match(I->getArgOperand(0),
m_Intrinsic<Intrinsic::is_fpclass>(
Expand Down Expand Up @@ -4428,7 +4564,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
FPClassTest FilterRHS = fcAllFlags;

Value *TestedValue = nullptr;
FPClassTest TestedMask = fcNone;
FPClassTest MaskIfTrue = fcAllFlags;
FPClassTest MaskIfFalse = fcAllFlags;
uint64_t ClassVal = 0;
const Function *F = cast<Instruction>(Op)->getFunction();
CmpInst::Predicate Pred;
Expand All @@ -4440,20 +4577,22 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
// TODO: In some degenerate cases we can infer something if we try again
// without looking through sign operations.
bool LookThroughFAbsFNeg = CmpLHS != LHS && CmpLHS != RHS;
std::tie(TestedValue, TestedMask) =
fcmpToClassTest(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
std::tie(TestedValue, MaskIfTrue, MaskIfFalse) =
fcmpImpliesClass(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
} else if (match(Cond,
m_Intrinsic<Intrinsic::is_fpclass>(
m_Value(TestedValue), m_ConstantInt(ClassVal)))) {
TestedMask = static_cast<FPClassTest>(ClassVal);
FPClassTest TestedMask = static_cast<FPClassTest>(ClassVal);
MaskIfTrue = TestedMask;
MaskIfFalse = ~TestedMask;
}

if (TestedValue == LHS) {
// match !isnan(x) ? x : y
FilterLHS = TestedMask;
} else if (TestedValue == RHS) {
FilterLHS = MaskIfTrue;
} else if (TestedValue == RHS) { // && IsExactClass
// match !isnan(x) ? y : x
FilterRHS = ~TestedMask;
FilterRHS = MaskIfFalse;
}

KnownFPClass Known2;
Expand Down
Loading