Skip to content

Commit 4231162

Browse files
committed
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL
The proposed patch, in general, tries to transform the code sequence below:

    x  = 1.0 / sqrt(a);
    r1 = x * x;          // same as 1.0 / a
    r2 = a / sqrt(a);    // same as sqrt(a)

TO (if x, r1 and r2 are all used further in the code)

    tmp1 = 1.0 / a
    tmp2 = sqrt(a)
    tmp3 = tmp1 * tmp2
    x    = tmp3
    r1   = tmp1
    r2   = tmp2

The transform tries to make the high-latency sqrt and div operations independent of each other, and it also saves one multiplication.

The patch was tested with the SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4%. No regressions were observed. Also, no compile-time differences were observed with the patch.

Closes #54652
1 parent 86b1b06 commit 4231162

File tree

2 files changed

+807
-0
lines changed

2 files changed

+807
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "InstCombineInternal.h"
1515
#include "llvm/ADT/APInt.h"
16+
#include "llvm/ADT/SmallPtrSet.h"
1617
#include "llvm/ADT/SmallVector.h"
1718
#include "llvm/Analysis/InstructionSimplify.h"
1819
#include "llvm/Analysis/ValueTracking.h"
@@ -666,6 +667,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
666667
return nullptr;
667668
}
668669

670+
// If we have the following pattern,
671+
// X = 1.0/sqrt(a)
672+
// R1 = X * X
673+
// R2 = a/sqrt(a)
674+
// then this method collects all the instructions that match R1 and R2.
675+
static bool getFSqrtDivOptPattern(Instruction *Div,
676+
SmallPtrSetImpl<Instruction *> &R1,
677+
SmallPtrSetImpl<Instruction *> &R2) {
678+
Value *A;
679+
if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
680+
match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
681+
for (User *U : Div->users()) {
682+
Instruction *I = cast<Instruction>(U);
683+
if (match(I, m_FMul(m_Specific(Div), m_Specific(Div))))
684+
R1.insert(I);
685+
}
686+
687+
CallInst *CI = cast<CallInst>(Div->getOperand(1));
688+
for (User *U : CI->users()) {
689+
Instruction *I = cast<Instruction>(U);
690+
if (match(I, m_FDiv(m_Specific(A), m_Sqrt(m_Specific(A)))))
691+
R2.insert(I);
692+
}
693+
}
694+
return !R1.empty() && !R2.empty();
695+
}
696+
697+
// Check legality for transforming
698+
// x = 1.0/sqrt(a)
699+
// r1 = x * x;
700+
// r2 = a/sqrt(a);
701+
//
702+
// TO
703+
//
704+
// r1 = 1/a
705+
// r2 = sqrt(a)
706+
// x = r1 * r2
707+
// This transform works only when 'a' is known positive.
708+
static bool isFSqrtDivToFMulLegal(Instruction *X,
709+
SmallPtrSetImpl<Instruction *> &R1,
710+
SmallPtrSetImpl<Instruction *> &R2) {
711+
// Check if the required pattern for the transformation exists.
712+
if (!getFSqrtDivOptPattern(X, R1, R2))
713+
return false;
714+
715+
BasicBlock *BBx = X->getParent();
716+
BasicBlock *BBr1 = (*R1.begin())->getParent();
717+
BasicBlock *BBr2 = (*R2.begin())->getParent();
718+
719+
CallInst *FSqrt = cast<CallInst>(X->getOperand(1));
720+
if (!FSqrt->hasAllowReassoc() || !FSqrt->hasNoNaNs() ||
721+
!FSqrt->hasNoSignedZeros() || !FSqrt->hasNoInfs())
722+
return false;
723+
724+
// We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
725+
// by recip fp as it is strictly meant to transform ops of type a/b to
726+
// a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
727+
// has been used(rather abused)in the past for algebraic rewrites.
728+
if (!X->hasAllowReassoc() || !X->hasAllowReciprocal() || !X->hasNoInfs())
729+
return false;
730+
731+
// Check the constraints on X, R1 and R2 combined.
732+
// fdiv instruction and one of the multiplications must reside in the same
733+
// block. If not, the optimized code may execute more ops than before and
734+
// this may hamper the performance.
735+
if (BBx != BBr1 && BBx != BBr2)
736+
return false;
737+
738+
// Check the constraints on instructions in R1.
739+
if (any_of(R1, [BBr1](Instruction *I) {
740+
// When you have multiple instructions residing in R1 and R2
741+
// respectively, it's difficult to generate combinations of (R1,R2) and
742+
// then check if we have the required pattern. So, for now, just be
743+
// conservative.
744+
return (I->getParent() != BBr1 || !I->hasAllowReassoc());
745+
}))
746+
return false;
747+
748+
// Check the constraints on instructions in R2.
749+
return all_of(R2, [BBr2](Instruction *I) {
750+
// When you have multiple instructions residing in R1 and R2
751+
// respectively, it's difficult to generate combination of (R1,R2) and
752+
// then check if we have the required pattern. So, for now, just be
753+
// conservative.
754+
return (I->getParent() == BBr2 && I->hasAllowReassoc());
755+
});
756+
}
757+
669758
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
670759
Value *Op0 = I.getOperand(0);
671760
Value *Op1 = I.getOperand(1);
@@ -1917,6 +2006,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
19172006
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
19182007
}
19192008

2009+
// Change
2010+
// X = 1/sqrt(a)
2011+
// R1 = X * X
2012+
// R2 = a * X
2013+
//
2014+
// TO
2015+
//
2016+
// FDiv = 1/a
2017+
// FSqrt = sqrt(a)
2018+
// FMul = FDiv * FSqrt
2019+
// Replace Uses Of R1 With FDiv
2020+
// Replace Uses Of R2 With FSqrt
2021+
// Replace Uses Of X With FMul
2022+
static Instruction *
2023+
convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
2024+
const SmallPtrSetImpl<Instruction *> &R1,
2025+
const SmallPtrSetImpl<Instruction *> &R2,
2026+
InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
2027+
2028+
B.SetInsertPoint(X);
2029+
2030+
// Have an instruction that is representative of all of instructions in R1 and
2031+
// get the most common fpmath metadata and fast-math flags on it.
2032+
Value *SqrtOp = CI->getArgOperand(0);
2033+
auto *FDiv = cast<Instruction>(
2034+
B.CreateFDiv(ConstantFP::get(X->getType(), 1.0), SqrtOp));
2035+
auto *R1FPMathMDNode = (*R1.begin())->getMetadata(LLVMContext::MD_fpmath);
2036+
FastMathFlags R1FMF = (*R1.begin())->getFastMathFlags(); // Common FMF
2037+
for (Instruction *I : R1) {
2038+
R1FPMathMDNode = MDNode::getMostGenericFPMath(
2039+
R1FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
2040+
R1FMF &= I->getFastMathFlags();
2041+
IC->replaceInstUsesWith(*I, FDiv);
2042+
IC->eraseInstFromFunction(*I);
2043+
}
2044+
FDiv->setMetadata(LLVMContext::MD_fpmath, R1FPMathMDNode);
2045+
FDiv->copyFastMathFlags(R1FMF);
2046+
2047+
// Have a single sqrt call instruction that is representative of all of
2048+
// instructions in R2 and get the most common fpmath metadata and fast-math
2049+
// flags on it.
2050+
auto *FSqrt = cast<CallInst>(CI->clone());
2051+
FSqrt->insertBefore(CI);
2052+
auto *R2FPMathMDNode = (*R2.begin())->getMetadata(LLVMContext::MD_fpmath);
2053+
FastMathFlags R2FMF = (*R2.begin())->getFastMathFlags(); // Common FMF
2054+
for (Instruction *I : R2) {
2055+
R2FPMathMDNode = MDNode::getMostGenericFPMath(
2056+
R2FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
2057+
R2FMF &= I->getFastMathFlags();
2058+
IC->replaceInstUsesWith(*I, FSqrt);
2059+
IC->eraseInstFromFunction(*I);
2060+
}
2061+
FSqrt->setMetadata(LLVMContext::MD_fpmath, R2FPMathMDNode);
2062+
FSqrt->copyFastMathFlags(R2FMF);
2063+
2064+
Instruction *FMul;
2065+
// If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
2066+
if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
2067+
Value *Mul = B.CreateFMul(FDiv, FSqrt);
2068+
FMul = cast<Instruction>(B.CreateFNeg(Mul));
2069+
} else
2070+
FMul = cast<Instruction>(B.CreateFMul(FDiv, FSqrt));
2071+
FMul->copyMetadata(*X);
2072+
FMul->copyFastMathFlags(FastMathFlags::intersectRewrite(R1FMF, R2FMF) |
2073+
FastMathFlags::unionValue(R1FMF, R2FMF));
2074+
IC->replaceInstUsesWith(*X, FMul);
2075+
return IC->eraseInstFromFunction(*X);
2076+
}
2077+
19202078
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
19212079
Module *M = I.getModule();
19222080

@@ -1941,6 +2099,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
19412099
return R;
19422100

19432101
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
2102+
2103+
// Convert
2104+
// x = 1.0/sqrt(a)
2105+
// r1 = x * x;
2106+
// r2 = a/sqrt(a);
2107+
//
2108+
// TO
2109+
//
2110+
// r1 = 1/a
2111+
// r2 = sqrt(a)
2112+
// x = r1 * r2
2113+
SmallPtrSet<Instruction *, 2> R1, R2;
2114+
if (isFSqrtDivToFMulLegal(&I, R1, R2)) {
2115+
CallInst *CI = cast<CallInst>(I.getOperand(1));
2116+
if (Instruction *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, Builder, this))
2117+
return D;
2118+
}
2119+
19442120
if (isa<Constant>(Op0))
19452121
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
19462122
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)