Skip to content

[IRBuilder] Fold binary intrinsics #80743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 15, 2024
Merged
9 changes: 9 additions & 0 deletions llvm/include/llvm/Analysis/ConstantFolding.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#include <stdint.h>

namespace llvm {

// Local declaration of the Intrinsic::ID alias so that the declarations in
// this header can name intrinsic IDs.
// NOTE(review): presumably this must stay in sync with the alias declared in
// the intrinsics header -- confirm before changing the underlying type.
namespace Intrinsic {
using ID = unsigned;
}

class APInt;
template <typename T> class ArrayRef;
class CallBase;
Expand Down Expand Up @@ -186,6 +191,10 @@ Constant *ConstantFoldCall(const CallBase *Call, Function *F,
ArrayRef<Constant *> Operands,
const TargetLibraryInfo *TLI = nullptr);

/// Attempt to constant fold a call to the two-operand intrinsic \p ID with
/// the constant operands \p LHS and \p RHS and result type \p Ty.
///
/// \p FMFSource may be null. When it is a call instruction it is forwarded to
/// the folding logic as the call context; any other instruction is ignored.
///
/// Returns the folded constant, or null if the call could not be folded.
Constant *ConstantFoldBinaryIntrinsic(Intrinsic::ID ID, Constant *LHS,
Constant *RHS, Type *Ty,
Instruction *FMFSource);

/// ConstantFoldLoadThroughBitcast - try to cast constant to destination type
/// returning null if unsuccessful. Can cast pointer to pointer or pointer to
/// integer and vice versa if their sizes are equal.
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Analysis/InstSimplifyFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ class InstSimplifyFolder final : public IRBuilderFolder {
return simplifyCastInst(Op, V, DestTy, SQ);
}

/// Simplify a call to the two-operand intrinsic \p ID using this folder's
/// SimplifyQuery. \p FMFSource may be null; it is only passed on as the call
/// context when it is a CallBase.
Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
                           Instruction *FMFSource) const override {
  // Null and non-call instructions both yield a null call context.
  const auto *CallCtx = dyn_cast_if_present<CallBase>(FMFSource);
  return simplifyBinaryIntrinsic(ID, Ty, LHS, RHS, SQ, CallCtx);
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/InstructionSimplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ Value *simplifyExtractElementInst(Value *Vec, Value *Idx,
Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
const SimplifyQuery &Q);

/// Given operands for a BinaryIntrinsic, fold the result or return null.
Value *simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType, Value *Op0,
Value *Op1, const SimplifyQuery &Q,
const CallBase *Call);

/// Given operands for a ShuffleVectorInst, fold the result or return null.
/// See class ShuffleVectorInst for a description of the mask representation.
Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/Analysis/TargetFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ class TargetFolder final : public IRBuilderFolder {
return nullptr;
}

/// Fold a call to the two-operand intrinsic \p ID when both \p LHS and \p RHS
/// are constants, by delegating to ConstantFoldBinaryIntrinsic. Returns
/// nullptr when either operand is not a constant.
Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
Instruction *FMFSource) const override {
auto *C1 = dyn_cast<Constant>(LHS);
auto *C2 = dyn_cast<Constant>(RHS);
// Only fold when both operands are constants.
if (C1 && C2)
return ConstantFoldBinaryIntrinsic(ID, C1, C2, Ty, FMFSource);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't help constant folding :(

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, my bad. It is used to fold constrained fp intrinsics.

return nullptr;
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
10 changes: 8 additions & 2 deletions llvm/include/llvm/IR/ConstantFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/ConstantFold.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilderFolder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Operator.h"
Expand Down Expand Up @@ -89,7 +89,7 @@ class ConstantFolder final : public IRBuilderFolder {
}

Value *FoldUnOpFMF(Instruction::UnaryOps Opc, Value *V,
FastMathFlags FMF) const override {
FastMathFlags FMF) const override {
if (Constant *C = dyn_cast<Constant>(V))
return ConstantFoldUnaryInstruction(Opc, C);
return nullptr;
Expand Down Expand Up @@ -183,6 +183,12 @@ class ConstantFolder final : public IRBuilderFolder {
return nullptr;
}

/// Binary intrinsic folding hook. This folder never folds intrinsic calls and
/// always returns nullptr, regardless of the operands.
Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
Instruction *FMFSource) const override {
// Use TargetFolder or InstSimplifyFolder instead.
return nullptr;
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
20 changes: 10 additions & 10 deletions llvm/include/llvm/IR/IRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,9 @@ class IRBuilderBase {

/// Create a call to intrinsic \p ID with 2 operands which is mangled on the
/// first type.
CallInst *CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS,
Instruction *FMFSource = nullptr,
const Twine &Name = "");
Value *CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS,
Instruction *FMFSource = nullptr,
const Twine &Name = "");

/// Create a call to intrinsic \p ID with \p Args, mangled using \p Types. If
/// \p FMFSource is provided, copy fast-math-flags from that instruction to
Expand All @@ -983,7 +983,7 @@ class IRBuilderBase {
const Twine &Name = "");

/// Create call to the minnum intrinsic.
CallInst *CreateMinNum(Value *LHS, Value *RHS, const Twine &Name = "") {
Value *CreateMinNum(Value *LHS, Value *RHS, const Twine &Name = "") {
if (IsFPConstrained) {
return CreateConstrainedFPUnroundedBinOp(
Intrinsic::experimental_constrained_minnum, LHS, RHS, nullptr, Name);
Expand All @@ -993,7 +993,7 @@ class IRBuilderBase {
}

/// Create call to the maxnum intrinsic.
CallInst *CreateMaxNum(Value *LHS, Value *RHS, const Twine &Name = "") {
Value *CreateMaxNum(Value *LHS, Value *RHS, const Twine &Name = "") {
if (IsFPConstrained) {
return CreateConstrainedFPUnroundedBinOp(
Intrinsic::experimental_constrained_maxnum, LHS, RHS, nullptr, Name);
Expand All @@ -1003,19 +1003,19 @@ class IRBuilderBase {
}

/// Create call to the minimum intrinsic.
CallInst *CreateMinimum(Value *LHS, Value *RHS, const Twine &Name = "") {
Value *CreateMinimum(Value *LHS, Value *RHS, const Twine &Name = "") {
return CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS, nullptr, Name);
}

/// Create call to the maximum intrinsic.
CallInst *CreateMaximum(Value *LHS, Value *RHS, const Twine &Name = "") {
Value *CreateMaximum(Value *LHS, Value *RHS, const Twine &Name = "") {
return CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS, nullptr, Name);
}

/// Create call to the copysign intrinsic.
CallInst *CreateCopySign(Value *LHS, Value *RHS,
Instruction *FMFSource = nullptr,
const Twine &Name = "") {
Value *CreateCopySign(Value *LHS, Value *RHS,
Instruction *FMFSource = nullptr,
const Twine &Name = "") {
return CreateBinaryIntrinsic(Intrinsic::copysign, LHS, RHS, FMFSource,
Name);
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IRBuilderFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class IRBuilderFolder {
virtual Value *FoldCast(Instruction::CastOps Op, Value *V,
Type *DestTy) const = 0;

/// Try to fold a call to the two-operand intrinsic \p ID with operands
/// \p LHS and \p RHS and result type \p Ty. \p FMFSource is optional and
/// defaults to nullptr. Implementations return the folded value, or nullptr
/// if they do not fold the call.
virtual Value *
FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
Instruction *FMFSource = nullptr) const = 0;

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/NoFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class NoFolder final : public IRBuilderFolder {
return nullptr;
}

/// Binary intrinsic folding hook. Always returns nullptr, i.e. no folding is
/// ever performed by this folder.
Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
Instruction *FMFSource) const override {
return nullptr;
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down
138 changes: 83 additions & 55 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2529,12 +2529,73 @@ static Constant *evaluateCompare(const APFloat &Op1, const APFloat &Op2,
return nullptr;
}

static Constant *ConstantFoldScalarCall2(StringRef Name,
Intrinsic::ID IntrinsicID,
Type *Ty,
ArrayRef<Constant *> Operands,
const TargetLibraryInfo *TLI,
const CallBase *Call) {
/// Try to constant fold a call to a two-argument math library function.
///
/// \p Name is resolved to a LibFunc through \p TLI. Folding is attempted only
/// when \p TLI is non-null, the name is recognized, both entries of
/// \p Operands are ConstantFP values, and \p TLI reports the function as
/// available. Handled functions are pow, fmod, remainder and atan2, including
/// their float and "_finite" variants where listed below.
///
/// Returns the folded constant of type \p Ty, or nullptr if no fold applies.
static Constant *ConstantFoldLibCall2(StringRef Name, Type *Ty,
                                      ArrayRef<Constant *> Operands,
                                      const TargetLibraryInfo *TLI) {
  // Without library info the callee cannot be identified.
  if (!TLI)
    return nullptr;

  LibFunc Func = NotLibFunc;
  if (!TLI->getLibFunc(Name, Func))
    return nullptr;

  // Both arguments have to be floating-point constants.
  const auto *LhsFP = dyn_cast<ConstantFP>(Operands[0]);
  if (!LhsFP)
    return nullptr;

  const auto *RhsFP = dyn_cast<ConstantFP>(Operands[1]);
  if (!RhsFP)
    return nullptr;

  const APFloat &X = LhsFP->getValueAPF();
  const APFloat &Y = RhsFP->getValueAPF();

  switch (Func) {
  case LibFunc_pow:
  case LibFunc_powf:
  case LibFunc_pow_finite:
  case LibFunc_powf_finite:
    if (!TLI->has(Func))
      return nullptr;
    return ConstantFoldBinaryFP(pow, X, Y, Ty);

  case LibFunc_fmod:
  case LibFunc_fmodf: {
    if (!TLI->has(Func))
      return nullptr;
    // Fold only when the APFloat operation reports an exact, clean result.
    APFloat Result = X;
    if (Result.mod(Y) != APFloat::opStatus::opOK)
      return nullptr;
    return ConstantFP::get(Ty->getContext(), Result);
  }

  case LibFunc_remainder:
  case LibFunc_remainderf: {
    if (!TLI->has(Func))
      return nullptr;
    // Same policy as fmod: any non-OK status blocks the fold.
    APFloat Result = X;
    if (Result.remainder(Y) != APFloat::opStatus::opOK)
      return nullptr;
    return ConstantFP::get(Ty->getContext(), Result);
  }

  case LibFunc_atan2:
  case LibFunc_atan2f:
    // atan2(+/-0.0, +/-0.0) is known to raise an exception on some libm
    // (Solaris), so we do not assume a known result for that.
    if (X.isZero() && Y.isZero())
      return nullptr;
    [[fallthrough]];
  case LibFunc_atan2_finite:
  case LibFunc_atan2f_finite:
    if (!TLI->has(Func))
      return nullptr;
    return ConstantFoldBinaryFP(atan2, X, Y, Ty);

  default:
    // Any other library function is not handled here.
    return nullptr;
  }
}

static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
ArrayRef<Constant *> Operands,
const CallBase *Call) {
assert(Operands.size() == 2 && "Wrong number of operands.");

if (Ty->isFloatingPointTy()) {
Expand Down Expand Up @@ -2564,7 +2625,8 @@ static Constant *ConstantFoldScalarCall2(StringRef Name,
return nullptr;
const APFloat &Op2V = Op2->getValueAPF();

if (const auto *ConstrIntr = dyn_cast<ConstrainedFPIntrinsic>(Call)) {
if (const auto *ConstrIntr =
dyn_cast_if_present<ConstrainedFPIntrinsic>(Call)) {
RoundingMode RM = getEvaluationRoundingMode(ConstrIntr);
APFloat Res = Op1V;
APFloat::opStatus St;
Expand Down Expand Up @@ -2627,52 +2689,6 @@ static Constant *ConstantFoldScalarCall2(StringRef Name,
return ConstantFP::get(Ty->getContext(), Op1V * Op2V);
}

if (!TLI)
return nullptr;

LibFunc Func = NotLibFunc;
if (!TLI->getLibFunc(Name, Func))
return nullptr;

switch (Func) {
default:
break;
case LibFunc_pow:
case LibFunc_powf:
case LibFunc_pow_finite:
case LibFunc_powf_finite:
if (TLI->has(Func))
return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
break;
case LibFunc_fmod:
case LibFunc_fmodf:
if (TLI->has(Func)) {
APFloat V = Op1->getValueAPF();
if (APFloat::opStatus::opOK == V.mod(Op2->getValueAPF()))
return ConstantFP::get(Ty->getContext(), V);
}
break;
case LibFunc_remainder:
case LibFunc_remainderf:
if (TLI->has(Func)) {
APFloat V = Op1->getValueAPF();
if (APFloat::opStatus::opOK == V.remainder(Op2->getValueAPF()))
return ConstantFP::get(Ty->getContext(), V);
}
break;
case LibFunc_atan2:
case LibFunc_atan2f:
// atan2(+/-0.0, +/-0.0) is known to raise an exception on some libm
// (Solaris), so we do not assume a known result for that.
if (Op1V.isZero() && Op2V.isZero())
return nullptr;
[[fallthrough]];
case LibFunc_atan2_finite:
case LibFunc_atan2f_finite:
if (TLI->has(Func))
return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
break;
}
} else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
switch (IntrinsicID) {
case Intrinsic::ldexp: {
Expand Down Expand Up @@ -3163,8 +3179,13 @@ static Constant *ConstantFoldScalarCall(StringRef Name,
if (Operands.size() == 1)
return ConstantFoldScalarCall1(Name, IntrinsicID, Ty, Operands, TLI, Call);

if (Operands.size() == 2)
return ConstantFoldScalarCall2(Name, IntrinsicID, Ty, Operands, TLI, Call);
if (Operands.size() == 2) {
if (Constant *FoldedLibCall =
ConstantFoldLibCall2(Name, Ty, Operands, TLI)) {
return FoldedLibCall;
}
return ConstantFoldIntrinsicCall2(IntrinsicID, Ty, Operands, Call);
}

if (Operands.size() == 3)
return ConstantFoldScalarCall3(Name, IntrinsicID, Ty, Operands, TLI, Call);
Expand Down Expand Up @@ -3371,6 +3392,13 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID,

} // end anonymous namespace

/// Public entry point for folding a two-operand intrinsic call on constants.
/// Packs \p LHS and \p RHS into an operand list and forwards to
/// ConstantFoldIntrinsicCall2. \p FMFSource may be null; it is used as the
/// call context only when it is a CallBase.
Constant *llvm::ConstantFoldBinaryIntrinsic(Intrinsic::ID ID, Constant *LHS,
                                            Constant *RHS, Type *Ty,
                                            Instruction *FMFSource) {
  Constant *Operands[] = {LHS, RHS};
  const auto *CallCtx = dyn_cast_if_present<CallBase>(FMFSource);
  return ConstantFoldIntrinsicCall2(ID, Ty, Operands, CallCtx);
}

Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F,
ArrayRef<Constant *> Operands,
const TargetLibraryInfo *TLI) {
Expand Down
Loading