Skip to content

Revert "Revert "[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL"" #123313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
Expand Down Expand Up @@ -657,6 +658,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
return nullptr;
}

// Collects the instructions that, together with Div, form the pattern
//   X  = 1.0 / sqrt(a)        (a numerator of -1.0 is accepted as well)
//   R1 = X * X
//   R2 = a / sqrt(a)
//
// Div is the candidate X. Every user of Div of the form `Div * Div` is
// inserted into R1, and every user of the sqrt call of the form
// `a / sqrt(a)` is inserted into R2. Returns true when both R1 and R2 are
// non-empty after collection.
//
// A constant 'a' is rejected up front:
//  * the rewrite materializes `1.0 / a` through the IRBuilder and casts the
//    result to Instruction; with a constant divisor the builder may hand back
//    a folded constant instead of an instruction;
//  * if 'a' equals the +/-1.0 numerator, Div itself matches the R2 shape and
//    would be erased as a member of R2 while the rewrite still needs it.
// A sqrt call on a constant normally gets folded before we get here, so
// bailing out costs nothing in practice.
static bool getFSqrtDivOptPattern(Instruction *Div,
                                  SmallPtrSetImpl<Instruction *> &R1,
                                  SmallPtrSetImpl<Instruction *> &R2) {
  Value *A;
  if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
      match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
    // See the function comment: a constant sqrt operand is not handled.
    if (isa<Constant>(A))
      return false;

    // R1: squares of the reciprocal square root.
    for (User *U : Div->users()) {
      Instruction *I = cast<Instruction>(U);
      if (match(I, m_FMul(m_Specific(Div), m_Specific(Div))))
        R1.insert(I);
    }

    // R2: divisions of 'a' by the very same sqrt call that feeds Div.
    CallInst *CI = cast<CallInst>(Div->getOperand(1));
    for (User *U : CI->users()) {
      Instruction *I = cast<Instruction>(U);
      if (match(I, m_FDiv(m_Specific(A), m_Sqrt(m_Specific(A)))))
        R2.insert(I);
    }
  }
  return !R1.empty() && !R2.empty();
}

// Decide whether it is legal (and not obviously unprofitable) to rewrite
//   x  = 1.0 / sqrt(a)
//   r1 = x * x
//   r2 = a / sqrt(a)
// into
//   r1 = 1 / a
//   r2 = sqrt(a)
//   x  = r1 * r2
//
// X is the candidate reciprocal-sqrt division. On return R1 and R2 hold the
// instructions collected by getFSqrtDivOptPattern for the r1 and r2 roles.
// The transform works only when 'a' is known positive.
static bool isFSqrtDivToFMulLegal(Instruction *X,
                                  SmallPtrSetImpl<Instruction *> &R1,
                                  SmallPtrSetImpl<Instruction *> &R2) {
  // Without the full pattern there is nothing to transform.
  if (!getFSqrtDivOptPattern(X, R1, R2))
    return false;

  // The sqrt call must carry reassoc, nnan, nsz and ninf.
  auto *SqrtCall = cast<CallInst>(X->getOperand(1));
  const bool SqrtFlagsOk =
      SqrtCall->hasAllowReassoc() && SqrtCall->hasNoNaNs() &&
      SqrtCall->hasNoSignedZeros() && SqrtCall->hasNoInfs();
  if (!SqrtFlagsOk)
    return false;

  // Turning x = 1/sqrt(a) into x = sqrt(a) * 1/a is not covered by the
  // reciprocal flag alone, which is strictly about a/b -> a * 1/b. It is an
  // algebraic rewrite, and reassoc has historically been (ab)used to license
  // those, so both flags are required on the division, along with ninf.
  const bool DivFlagsOk =
      X->hasAllowReassoc() && X->hasAllowReciprocal() && X->hasNoInfs();
  if (!DivFlagsOk)
    return false;

  // The division has to share a block with at least one of the two groups;
  // otherwise the rewritten code may execute more operations than before and
  // hurt performance.
  BasicBlock *DivBB = X->getParent();
  BasicBlock *SquareBB = (*R1.begin())->getParent();
  BasicBlock *QuotientBB = (*R2.begin())->getParent();
  if (DivBB != SquareBB && DivBB != QuotientBB)
    return false;

  // With several instructions in R1 and R2 it is hard to enumerate the
  // (R1, R2) combinations and validate each of them, so stay conservative:
  // every member of a group must live in that group's block and be reassoc.
  auto ReassocInBlock = [](BasicBlock *BB) {
    return [BB](Instruction *Member) {
      return Member->getParent() == BB && Member->hasAllowReassoc();
    };
  };
  return all_of(R1, ReassocInBlock(SquareBB)) &&
         all_of(R2, ReassocInBlock(QuotientBB));
}

Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
Value *Op0 = I.getOperand(0);
Value *Op1 = I.getOperand(1);
Expand Down Expand Up @@ -1913,6 +2002,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
}

// Rewrite
//   X  = 1/sqrt(a)          (or -1/sqrt(a))
//   R1 = X * X
//   R2 = a / sqrt(a)
//
// into
//
//   FDiv  = 1/a
//   FSqrt = sqrt(a)
//   FMul  = FDiv * FSqrt    (negated when X had a -1.0 numerator)
//
// replacing all uses of every R1 member with FDiv, of every R2 member with
// FSqrt, and of X with FMul. CI is the sqrt call feeding X. The replaced
// instructions are erased through IC; the result of erasing X is returned.
static Instruction *
convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
                        const SmallPtrSetImpl<Instruction *> &R1,
                        const SmallPtrSetImpl<Instruction *> &R2,
                        InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
  B.SetInsertPoint(X);

  // Redirects every user of the instructions in Group to Repl and erases the
  // group. Repl then receives the most generic fpmath metadata and the
  // intersection of the fast-math flags found on the group; that intersection
  // is handed back to the caller.
  auto ReplaceGroupWith = [IC](const SmallPtrSetImpl<Instruction *> &Group,
                               Instruction *Repl) -> FastMathFlags {
    Instruction *Seed = *Group.begin();
    MDNode *FPMath = Seed->getMetadata(LLVMContext::MD_fpmath);
    FastMathFlags CommonFMF = Seed->getFastMathFlags();
    for (Instruction *Old : Group) {
      FPMath = MDNode::getMostGenericFPMath(
          FPMath, Old->getMetadata(LLVMContext::MD_fpmath));
      CommonFMF &= Old->getFastMathFlags();
      IC->replaceInstUsesWith(*Old, Repl);
      IC->eraseInstFromFunction(*Old);
    }
    Repl->setMetadata(LLVMContext::MD_fpmath, FPMath);
    Repl->copyFastMathFlags(CommonFMF);
    return CommonFMF;
  };

  // One reciprocal 1/a stands in for all the X * X instructions.
  Value *SqrtOp = CI->getArgOperand(0);
  Value *One = ConstantFP::get(X->getType(), 1.0);
  auto *FDiv = cast<Instruction>(B.CreateFDiv(One, SqrtOp));
  const FastMathFlags R1FMF = ReplaceGroupWith(R1, FDiv);

  // One sqrt call, cloned from CI and placed right before it, stands in for
  // all the a / sqrt(a) instructions.
  auto *FSqrt = cast<CallInst>(CI->clone());
  FSqrt->insertBefore(CI);
  const FastMathFlags R2FMF = ReplaceGroupWith(R2, FSqrt);

  // Rebuild X as FDiv * FSqrt; if X was -1/sqrt(a), negate the product.
  const bool IsNegated = match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)));
  Value *Product = B.CreateFMul(FDiv, FSqrt);
  if (IsNegated)
    Product = B.CreateFNeg(Product);
  auto *FMul = cast<Instruction>(Product);

  FMul->copyMetadata(*X);
  FMul->copyFastMathFlags(FastMathFlags::intersectRewrite(R1FMF, R2FMF) |
                          FastMathFlags::unionValue(R1FMF, R2FMF));
  IC->replaceInstUsesWith(*X, FMul);
  return IC->eraseInstFromFunction(*X);
}

Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
Module *M = I.getModule();

Expand All @@ -1937,6 +2095,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return R;

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

// Convert
// x = 1.0/sqrt(a)
// r1 = x * x;
// r2 = a/sqrt(a);
//
// TO
//
// r1 = 1/a
// r2 = sqrt(a)
// x = r1 * r2
SmallPtrSet<Instruction *, 2> R1, R2;
if (isFSqrtDivToFMulLegal(&I, R1, R2)) {
CallInst *CI = cast<CallInst>(I.getOperand(1));
if (Instruction *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, Builder, this))
return D;
}

if (isa<Constant>(Op0))
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
if (Instruction *R = FoldOpIntoSelect(I, SI))
Expand Down
Loading
Loading