Skip to content

Commit e00d139

Browse files
committed
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL
The proposed patch, in general, tries to transform the below code sequence: x = 1.0 / sqrt (a); r1 = x * x; // same as 1.0 / a r2 = a / sqrt(a); // same as sqrt (a) TO (If x, r1 and r2 are all used further in the code) tmp1 = 1.0 / a tmp2 = sqrt (a) tmp3 = tmp1 * tmp2 x = tmp3 r1 = tmp1 r2 = tmp2 The transform tries to make high latency sqrt and div operations independent and also saves on one multiplication. The patch was tested with SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4% No other regressions were observed. Also, no compile time differences were observed with the patch. Closes #54652
1 parent e05c1b4 commit e00d139

File tree

2 files changed

+591
-0
lines changed

2 files changed

+591
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,89 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
626626
return nullptr;
627627
}
628628

629+
// Check legality for transforming
630+
// x = 1.0/sqrt(a)
631+
// r1 = x * x;
632+
// r2 = a/sqrt(a);
633+
//
634+
// TO
635+
//
636+
// r1 = 1/a
637+
// r2 = sqrt(a)
638+
// x = r1 * r2
639+
static bool isFSqrtDivToFMulLegal(Instruction *X,
640+
const SmallVectorImpl<Instruction *> &R1,
641+
const SmallVectorImpl<Instruction *> &R2) {
642+
BasicBlock *BBx = X->getParent();
643+
BasicBlock *BBr1 = R1[0]->getParent();
644+
BasicBlock *BBr2 = R2[0]->getParent();
645+
646+
CallInst *FSqrt = cast<CallInst>(X->getOperand(1));
647+
if (!(FSqrt->hasAllowReassoc() && FSqrt->hasNoNaNs() &&
648+
FSqrt->hasNoSignedZeros() && FSqrt->hasNoInfs()))
649+
return false;
650+
651+
// We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
652+
// by recip fp as it is strictly meant to transform ops of type a/b to
653+
// a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
654+
// has been used(rather abused)in the past for algebraic rewrites.
655+
if (!(X->hasAllowReassoc() && X->hasAllowReciprocal() && X->hasNoInfs()))
656+
return false;
657+
658+
// Check the constraints on instructions in R1.
659+
if (!all_of(R1, [BBr1](Instruction *I) {
660+
// When you have multiple instructions residing in R1 and R2
661+
// respectively, it's difficult to generate combinations of (R1,R2) and
662+
// then check if we have the required pattern. So, for now, just be
663+
// conservative.
664+
return (I->getParent() == BBr1 && I->hasAllowReassoc());
665+
}))
666+
return false;
667+
668+
// Check the constraints on instructions in R2.
669+
if (!all_of(R2, [BBr2](Instruction *I) {
670+
// When you have multiple instructions residing in R1 and R2
671+
// respectively, it's difficult to generate combination of (R1,R2) and
672+
// then check if we have the required pattern. So, for now, just be
673+
// conservative.
674+
return (I->getParent() == BBr2 && I->hasAllowReassoc());
675+
}))
676+
return false;
677+
678+
// Check the constraints on X, R1 and R2 combined.
679+
// fdiv instruction and one of the multiplications must reside in the same
680+
// block. If not, the optimized code may execute more ops than before and
681+
// this may hamper the performance.
682+
return (BBx == BBr1 || BBx == BBr2);
683+
}
684+
685+
static void getFSqrtDivOptPattern(Instruction *Div,
686+
SmallVectorImpl<Instruction *> &R1,
687+
SmallVectorImpl<Instruction *> &R2) {
688+
Value *A;
689+
if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
690+
match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
691+
for (User *U : Div->users()) {
692+
Instruction *I = dyn_cast<Instruction>(U);
693+
if (!(I && I->getOpcode() == Instruction::FMul))
694+
continue;
695+
696+
if (match(I, m_FMul(m_Specific(Div), m_Specific(Div)))) {
697+
R1.push_back(I);
698+
continue;
699+
}
700+
}
701+
CallInst *CI = cast<CallInst>(Div->getOperand(1));
702+
for (User *U : CI->users()) {
703+
Instruction *I = dyn_cast<Instruction>(U);
704+
if (match(I, m_FDiv(m_Specific(A), m_Sqrt(m_Specific(A))))) {
705+
R2.push_back(I);
706+
continue;
707+
}
708+
}
709+
}
710+
}
711+
629712
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
630713
Value *Op0 = I.getOperand(0);
631714
Value *Op1 = I.getOperand(1);
@@ -1796,6 +1879,35 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
17961879
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
17971880
}
17981881

1882+
static Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
1883+
const SmallVectorImpl<Instruction *> &R1,
1884+
const SmallVectorImpl<Instruction *> &R2,
1885+
Value *SqrtOp,
1886+
InstCombiner::BuilderTy &B) {
1887+
// 1. synthesize tmp1 = 1/a and replace uses of r1
1888+
B.SetInsertPoint(X);
1889+
Value *Tmp1 =
1890+
B.CreateFDivFMF(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp, R1[0]);
1891+
for (auto *I : R1)
1892+
I->replaceAllUsesWith(Tmp1);
1893+
1894+
// 2. No need of synthesizing Tmp2 again. In this scenario, tmp2 = CI. Replace
1895+
// uses of r2 with tmp2
1896+
for (auto *I : R2)
1897+
I->replaceAllUsesWith(CI);
1898+
1899+
// 3. synthesize tmp3 = tmp1 * tmp2 . Replace uses of 'x' with tmp3
1900+
Value *Tmp3;
1901+
// If x = -1/sqrt(a) initially,then Tmp3 = -(Tmp1*tmp2)
1902+
if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
1903+
Value *Mul = B.CreateFMul(Tmp1, CI);
1904+
Tmp3 = B.CreateFNegFMF(Mul, X);
1905+
} else
1906+
Tmp3 = B.CreateFMulFMF(Tmp1, CI, X);
1907+
1908+
return Tmp3;
1909+
}
1910+
17991911
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18001912
Module *M = I.getModule();
18011913

@@ -1820,6 +1932,26 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18201932
return R;
18211933

18221934
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1935+
1936+
// Convert
1937+
// x = 1.0/sqrt(a)
1938+
// r1 = x * x;
1939+
// r2 = a/sqrt(a);
1940+
//
1941+
// TO
1942+
//
1943+
// r1 = 1/a
1944+
// r2 = sqrt(a)
1945+
// x = r1 * r2
1946+
SmallVector<Instruction *, 2> R1, R2;
1947+
getFSqrtDivOptPattern(&I, R1, R2);
1948+
if (!R1.empty() && !R2.empty() && isFSqrtDivToFMulLegal(&I, R1, R2)) {
1949+
CallInst *CI = cast<CallInst>(I.getOperand(1));
1950+
Value *SqrtOp = CI->getArgOperand(0);
1951+
if (Value *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, SqrtOp, Builder))
1952+
return replaceInstUsesWith(I, D);
1953+
}
1954+
18231955
if (isa<Constant>(Op0))
18241956
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
18251957
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)