Skip to content

Commit 0245bc8

Browse files
committed
[GlobalISel] Add computeFPClass to GlobaISelValueTracking
1 parent 37deb09 commit 0245bc8

File tree

8 files changed

+1662
-4
lines changed

8 files changed

+1662
-4
lines changed

llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
#ifndef LLVM_CODEGEN_GLOBALISEL_GISELVALUETRACKING_H
1515
#define LLVM_CODEGEN_GLOBALISEL_GISELVALUETRACKING_H
1616

17+
#include "llvm/ADT/APFloat.h"
1718
#include "llvm/ADT/DenseMap.h"
1819
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
1920
#include "llvm/CodeGen/MachineFunctionPass.h"
2021
#include "llvm/CodeGen/Register.h"
22+
#include "llvm/IR/InstrTypes.h"
2123
#include "llvm/InitializePasses.h"
2224
#include "llvm/Support/KnownBits.h"
25+
#include "llvm/Support/KnownFPClass.h"
2326

2427
namespace llvm {
2528

@@ -41,6 +44,64 @@ class GISelValueTracking : public GISelChangeObserver {
4144
unsigned computeNumSignBitsMin(Register Src0, Register Src1,
4245
const APInt &DemandedElts, unsigned Depth = 0);
4346

47+
/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
48+
/// same result as an fcmp with the given operands.
49+
///
50+
/// If \p LookThroughSrc is true, consider the input value when computing the
51+
/// mask.
52+
///
53+
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first
54+
/// pair element will always be LHS.
55+
std::pair<Register, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
56+
const MachineFunction &MF,
57+
Register LHS, Value *RHS,
58+
bool LookThroughSrc = true);
59+
std::pair<Register, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
60+
const MachineFunction &MF,
61+
Register LHS,
62+
const APFloat *ConstRHS,
63+
bool LookThroughSrc = true);
64+
65+
/// Compute the possible floating-point classes that \p LHS could be based on
66+
/// fcmp \Pred \p LHS, \p RHS.
67+
///
68+
/// \returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
69+
///
70+
/// If the compare returns an exact class test, ClassesIfTrue ==
71+
/// ~ClassesIfFalse
72+
///
73+
/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
74+
/// only succeed for a test of x > 0 implies positive, but not x > 1).
75+
///
76+
/// If \p LookThroughSrc is true, consider the input value when computing the
77+
/// mask. This may look through sign bit operations.
78+
///
79+
/// If \p LookThroughSrc is false, ignore the source value (i.e. the first
80+
/// pair element will always be LHS.
81+
///
82+
std::tuple<Register, FPClassTest, FPClassTest>
83+
fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
84+
Register LHS, Register RHS, bool LookThroughSrc = true);
85+
std::tuple<Register, FPClassTest, FPClassTest>
86+
fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
87+
Register LHS, FPClassTest RHS, bool LookThroughSrc = true);
88+
std::tuple<Register, FPClassTest, FPClassTest>
89+
fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
90+
Register LHS, const APFloat &RHS,
91+
bool LookThroughSrc = true);
92+
93+
void computeKnownFPClass(Register R, KnownFPClass &Known,
94+
FPClassTest InterestedClasses, unsigned Depth);
95+
96+
void computeKnownFPClassForFPTrunc(const MachineInstr &MI,
97+
const APInt &DemandedElts,
98+
FPClassTest InterestedClasses,
99+
KnownFPClass &Known, unsigned Depth);
100+
101+
void computeKnownFPClass(Register R, const APInt &DemandedElts,
102+
FPClassTest InterestedClasses, KnownFPClass &Known,
103+
unsigned Depth);
104+
44105
public:
45106
GISelValueTracking(MachineFunction &MF, unsigned MaxDepth = 6);
46107
virtual ~GISelValueTracking() = default;
@@ -86,6 +147,34 @@ class GISelValueTracking : public GISelChangeObserver {
86147
/// \return The known alignment for the pointer-like value \p R.
87148
Align computeKnownAlignment(Register R, unsigned Depth = 0);
88149

150+
/// Determine which floating-point classes are valid for \p V, and return them
151+
/// in KnownFPClass bit sets.
152+
///
153+
/// This function is defined on values with floating-point type, values
154+
/// vectors of floating-point type, and arrays of floating-point type.
155+
156+
/// \p InterestedClasses is a compile time optimization hint for which
157+
/// floating point classes should be queried. Queries not specified in \p
158+
/// InterestedClasses should be reliable if they are determined during the
159+
/// query.
160+
KnownFPClass computeKnownFPClass(Register R, const APInt &DemandedElts,
161+
FPClassTest InterestedClasses,
162+
unsigned Depth);
163+
164+
KnownFPClass computeKnownFPClass(Register R,
165+
FPClassTest InterestedClasses = fcAllFlags,
166+
unsigned Depth = 0);
167+
168+
/// Wrapper to account for known fast math flags at the use instruction.
169+
KnownFPClass computeKnownFPClass(Register R, const APInt &DemandedElts,
170+
uint32_t Flags,
171+
FPClassTest InterestedClasses,
172+
unsigned Depth);
173+
174+
KnownFPClass computeKnownFPClass(Register R, uint32_t Flags,
175+
FPClassTest InterestedClasses,
176+
unsigned Depth);
177+
89178
// Observer API. No-op for non-caching implementation.
90179
void erasingInstr(MachineInstr &MI) override {}
91180
void createdInstr(MachineInstr &MI) override {}

llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define LLVM_CODEGEN_GLOBALISEL_MIPATTERNMATCH_H
1515

1616
#include "llvm/ADT/APInt.h"
17+
#include "llvm/ADT/FloatingPointMode.h"
1718
#include "llvm/CodeGen/GlobalISel/Utils.h"
1819
#include "llvm/CodeGen/MachineRegisterInfo.h"
1920
#include "llvm/IR/InstrTypes.h"
@@ -393,6 +394,7 @@ inline bind_ty<const MachineInstr *> m_MInstr(const MachineInstr *&MI) {
393394
inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
394395
inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
395396
inline operand_type_match m_Pred() { return operand_type_match(); }
397+
inline bind_ty<FPClassTest> m_FPClassTest(FPClassTest &T) { return T; }
396398

397399
template <typename BindTy> struct deferred_helper {
398400
static bool match(const MachineRegisterInfo &MRI, BindTy &VR, BindTy &V) {
@@ -762,6 +764,32 @@ struct CompareOp_match {
762764
}
763765
};
764766

767+
template <typename LHS_P, typename Test_P, unsigned Opcode>
768+
struct ClassifyOp_match {
769+
LHS_P L;
770+
Test_P T;
771+
772+
ClassifyOp_match(const LHS_P &LHS, const Test_P &Tst) : L(LHS), T(Tst) {}
773+
774+
template <typename OpTy>
775+
bool match(const MachineRegisterInfo &MRI, OpTy &&Op) {
776+
MachineInstr *TmpMI;
777+
if (!mi_match(Op, MRI, m_MInstr(TmpMI)) || TmpMI->getOpcode() != Opcode)
778+
return false;
779+
780+
Register LHS = TmpMI->getOperand(1).getReg();
781+
if (!L.match(MRI, LHS))
782+
return false;
783+
784+
FPClassTest TmpClass =
785+
static_cast<FPClassTest>(TmpMI->getOperand(2).getImm());
786+
if (T.match(MRI, TmpClass))
787+
return true;
788+
789+
return false;
790+
}
791+
};
792+
765793
template <typename Pred, typename LHS, typename RHS>
766794
inline CompareOp_match<Pred, LHS, RHS, TargetOpcode::G_ICMP>
767795
m_GICmp(const Pred &P, const LHS &L, const RHS &R) {
@@ -804,6 +832,13 @@ m_c_GFCmp(const Pred &P, const LHS &L, const RHS &R) {
804832
return CompareOp_match<Pred, LHS, RHS, TargetOpcode::G_FCMP, true>(P, L, R);
805833
}
806834

835+
/// Matches a floating point class test
836+
template <typename LHS, typename Test>
837+
inline ClassifyOp_match<LHS, Test, TargetOpcode::G_IS_FPCLASS>
838+
m_GIsFPClass(const LHS &L, const Test &T) {
839+
return ClassifyOp_match<LHS, Test, TargetOpcode::G_IS_FPCLASS>(L, T);
840+
}
841+
807842
// Helper for checking if a Reg is of specific type.
808843
struct CheckType {
809844
LLT Ty;

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ std::optional<APInt> getIConstantVRegVal(Register VReg,
183183
std::optional<int64_t> getIConstantVRegSExtVal(Register VReg,
184184
const MachineRegisterInfo &MRI);
185185

186+
/// If \p VReg is defined by a G_CONSTANT fits in uint64_t returns it.
187+
std::optional<uint64_t> getIConstantVRegZExtVal(Register VReg,
188+
const MachineRegisterInfo &MRI);
189+
186190
/// \p VReg is defined by a G_CONSTANT, return the corresponding value.
187191
const APInt &getIConstantFromReg(Register VReg, const MachineRegisterInfo &MRI);
188192

@@ -438,6 +442,17 @@ std::optional<int64_t> getIConstantSplatSExtVal(const Register Reg,
438442
std::optional<int64_t> getIConstantSplatSExtVal(const MachineInstr &MI,
439443
const MachineRegisterInfo &MRI);
440444

445+
/// \returns the scalar sign extended integral splat value of \p Reg if
446+
/// possible.
447+
std::optional<uint64_t>
448+
getIConstantSplatZExtVal(const Register Reg, const MachineRegisterInfo &MRI);
449+
450+
/// \returns the scalar sign extended integral splat value defined by \p MI if
451+
/// possible.
452+
std::optional<uint64_t>
453+
getIConstantSplatZExtVal(const MachineInstr &MI,
454+
const MachineRegisterInfo &MRI);
455+
441456
/// Returns a floating point scalar constant of a build vector splat if it
442457
/// exists. When \p AllowUndef == true some elements can be undef but not all.
443458
std::optional<FPValueAndVReg> getFConstantSplat(Register VReg,
@@ -654,6 +669,9 @@ class GIConstant {
654669
/// }
655670
/// provides low-level access.
656671
class GFConstant {
672+
using VecTy = SmallVector<APFloat>;
673+
using const_iterator = VecTy::const_iterator;
674+
657675
public:
658676
enum class GFConstantKind { Scalar, FixedVector, ScalableVector };
659677

@@ -671,6 +689,23 @@ class GFConstant {
671689
/// Returns the kind of of this constant, e.g, Scalar.
672690
GFConstantKind getKind() const { return Kind; }
673691

692+
const_iterator begin() const {
693+
assert(Kind != GFConstantKind::ScalableVector &&
694+
"Expected fixed vector or scalar constant");
695+
return Values.begin();
696+
}
697+
698+
const_iterator end() const {
699+
assert(Kind != GFConstantKind::ScalableVector &&
700+
"Expected fixed vector or scalar constant");
701+
return Values.end();
702+
}
703+
704+
size_t size() const {
705+
assert(Kind == GFConstantKind::FixedVector && "Expected fixed vector");
706+
return Values.size();
707+
}
708+
674709
/// Returns the value, if this constant is a scalar.
675710
APFloat getScalarValue() const;
676711

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "llvm/Support/AtomicOrdering.h"
5252
#include "llvm/Support/Casting.h"
5353
#include "llvm/Support/ErrorHandling.h"
54+
#include "llvm/Support/KnownFPClass.h"
5455
#include <algorithm>
5556
#include <cassert>
5657
#include <climits>
@@ -4165,6 +4166,13 @@ class TargetLowering : public TargetLoweringBase {
41654166
const MachineRegisterInfo &MRI,
41664167
unsigned Depth = 0) const;
41674168

4169+
virtual void computeKnownFPClassForTargetInstr(GISelValueTracking &Analysis,
4170+
Register R,
4171+
KnownFPClass &Known,
4172+
const APInt &DemandedElts,
4173+
const MachineRegisterInfo &MRI,
4174+
unsigned Depth = 0) const;
4175+
41684176
/// Determine the known alignment for the pointer value \p R. This is can
41694177
/// typically be inferred from the number of low known 0 bits. However, for a
41704178
/// pointer with a non-integral address space, the alignment value may be

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/ADT/APFloat.h"
1616
#include "llvm/ADT/APInt.h"
1717
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/FloatingPointMode.h"
1819
#include "llvm/ADT/STLExtras.h"
1920
#include "llvm/ADT/ScopeExit.h"
2021
#include "llvm/ADT/SmallPtrSet.h"

0 commit comments

Comments
 (0)