Skip to content

Commit 4c84c82

Browse files
committed
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL
The proposed patch, in general, tries to transform the code sequence below:

    x  = 1.0 / sqrt(a);
    r1 = x * x;          // same as 1.0 / a
    r2 = a / sqrt(a);    // same as sqrt(a)

TO (if x, r1 and r2 are all used further in the code):

    tmp1 = 1.0 / a
    tmp2 = sqrt(a)
    tmp3 = tmp1 * tmp2
    x  = tmp3
    r1 = tmp1
    r2 = tmp2

The transform makes the high-latency sqrt and div operations independent of each other and also saves one multiplication. The patch was tested with the SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4%. No regressions were observed. Also, no compile-time differences were observed with the patch. Closes #54652
1 parent e05c1b4 commit 4c84c82

File tree

2 files changed

+649
-0
lines changed

2 files changed

+649
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,88 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
626626
return nullptr;
627627
}
628628

629+
// Check legality for transforming
630+
// x = 1.0/sqrt(a)
631+
// r1 = x * x;
632+
// r2 = a/sqrt(a);
633+
//
634+
// TO
635+
//
636+
// r1 = 1/a
637+
// r2 = sqrt(a)
638+
// x = r1 * r2
639+
static bool isFSqrtDivToFMulLegal(Instruction *X, ArrayRef<Instruction *> R1,
640+
ArrayRef<Instruction *> R2) {
641+
BasicBlock *BBx = X->getParent();
642+
BasicBlock *BBr1 = R1[0]->getParent();
643+
BasicBlock *BBr2 = R2[0]->getParent();
644+
645+
CallInst *FSqrt = cast<CallInst>(X->getOperand(1));
646+
if (!FSqrt->hasAllowReassoc() || !FSqrt->hasNoNaNs() ||
647+
!FSqrt->hasNoSignedZeros() || !FSqrt->hasNoInfs())
648+
return false;
649+
650+
// We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
651+
// by recip fp as it is strictly meant to transform ops of type a/b to
652+
// a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
653+
// has been used(rather abused)in the past for algebraic rewrites.
654+
if (!X->hasAllowReassoc() || !X->hasAllowReciprocal() || !X->hasNoInfs())
655+
return false;
656+
657+
// Check the constraints on instructions in R1.
658+
if (any_of(R1, [BBr1](Instruction *I) {
659+
// When you have multiple instructions residing in R1 and R2
660+
// respectively, it's difficult to generate combinations of (R1,R2) and
661+
// then check if we have the required pattern. So, for now, just be
662+
// conservative.
663+
return (I->getParent() != BBr1 || !I->hasAllowReassoc());
664+
}))
665+
return false;
666+
667+
// Check the constraints on instructions in R2.
668+
if (any_of(R2, [BBr2](Instruction *I) {
669+
// When you have multiple instructions residing in R1 and R2
670+
// respectively, it's difficult to generate combination of (R1,R2) and
671+
// then check if we have the required pattern. So, for now, just be
672+
// conservative.
673+
return (I->getParent() != BBr2 || !I->hasAllowReassoc());
674+
}))
675+
return false;
676+
677+
// Check the constraints on X, R1 and R2 combined.
678+
// fdiv instruction and one of the multiplications must reside in the same
679+
// block. If not, the optimized code may execute more ops than before and
680+
// this may hamper the performance.
681+
return (BBx == BBr1 || BBx == BBr2);
682+
}
683+
684+
// Gather the instructions that complete the pattern rooted at \p Div.
//
// Nothing is collected unless \p Div has the form `1.0 / sqrt(a)` or
// `-1.0 / sqrt(a)`. When it does:
//   * every user of \p Div of the form `Div * Div` is appended to \p R1;
//   * every user of the sqrt call of the form `a / sqrt(a)` is appended to
//     \p R2.
// The output vectors are only appended to, never cleared.
static void getFSqrtDivOptPattern(Instruction *Div,
                                  SmallVectorImpl<Instruction *> &R1,
                                  SmallVectorImpl<Instruction *> &R2) {
  Value *A;
  const bool IsRecipSqrt =
      match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
      match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))));
  if (!IsRecipSqrt)
    return;

  // Squares of the reciprocal square root: Div * Div.
  for (User *DivUser : Div->users()) {
    auto *Mul = dyn_cast<Instruction>(DivUser);
    if (Mul && Mul->getOpcode() == Instruction::FMul &&
        match(Mul, m_FMul(m_Specific(Div), m_Specific(Div))))
      R1.push_back(Mul);
  }

  // Divisions of the sqrt argument by the very same sqrt: a / sqrt(a).
  auto *SqrtCall = cast<CallInst>(Div->getOperand(1));
  for (User *SqrtUser : SqrtCall->users()) {
    auto *Candidate = cast<Instruction>(SqrtUser);
    if (match(Candidate, m_FDiv(m_Specific(A), m_Sqrt(m_Specific(A)))))
      R2.push_back(Candidate);
  }
}
710+
629711
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
630712
Value *Op0 = I.getOperand(0);
631713
Value *Op1 = I.getOperand(1);
@@ -1796,6 +1878,64 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
17961878
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
17971879
}
17981880

1881+
// Change
1882+
// X = 1/sqrt(a)
1883+
// R1 = X * X
1884+
// R2 = a * X
1885+
//
1886+
// TO
1887+
//
1888+
// Tmp1 = 1/a
1889+
// Tmp2 = sqrt(a)
1890+
// Tmp3 = Tmp1 * Tmp2
1891+
// Replace Uses Of R1 With Tmp1
1892+
// Replace Uses Of R2 With Tmp2
1893+
// Replace Uses Of X With Tmp3
1894+
static Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
1895+
ArrayRef<Instruction *> R1,
1896+
ArrayRef<Instruction *> R2, Value *SqrtOp,
1897+
InstCombiner::BuilderTy &B) {
1898+
1899+
B.SetInsertPoint(X);
1900+
1901+
// Every instance of R1 may have different fpmath metadata and fpmath flags.
1902+
// We try to preserve them by having seperate fdiv instruction per R1
1903+
// instance.
1904+
Instruction *Tmp1;
1905+
for (Instruction *I : R1) {
1906+
Tmp1 = cast<Instruction>(
1907+
B.CreateFDiv(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp));
1908+
Tmp1->copyMetadata(*I);
1909+
Tmp1->copyFastMathFlags(I);
1910+
I->replaceAllUsesWith(Tmp1);
1911+
}
1912+
1913+
// Although, by value, Tmp2 = CI , every instance of R2 may have different
1914+
// fpmath metadata and fpmath flags. We try to preserve them by cloning the
1915+
// call instruction per R2 instance.
1916+
CallInst *Sqrt = B.CreateUnaryIntrinsic(Intrinsic::sqrt, SqrtOp);
1917+
Instruction *Tmp2;
1918+
for (Instruction *I : R2) {
1919+
Tmp2 = Sqrt->clone();
1920+
Tmp2->insertBefore(CI);
1921+
Tmp2->setName("sqrt");
1922+
Tmp2->copyFastMathFlags(I);
1923+
Tmp2->copyMetadata(*I);
1924+
I->replaceAllUsesWith(Tmp2);
1925+
}
1926+
1927+
Instruction *Tmp3;
1928+
// If X = -1/sqrt(a) initially,then Tmp3 = -(Tmp1*tmp2)
1929+
if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
1930+
Value *Mul = B.CreateFMul(Tmp1, Tmp2);
1931+
Tmp3 = cast<Instruction>(B.CreateFNegFMF(Mul, X));
1932+
} else
1933+
Tmp3 = cast<Instruction>(B.CreateFMulFMF(Tmp1, Tmp2, X));
1934+
Tmp3->copyMetadata(*X);
1935+
1936+
return Tmp3;
1937+
}
1938+
17991939
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18001940
Module *M = I.getModule();
18011941

@@ -1820,6 +1960,26 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18201960
return R;
18211961

18221962
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1963+
1964+
// Convert
1965+
// x = 1.0/sqrt(a)
1966+
// r1 = x * x;
1967+
// r2 = a/sqrt(a);
1968+
//
1969+
// TO
1970+
//
1971+
// r1 = 1/a
1972+
// r2 = sqrt(a)
1973+
// x = r1 * r2
1974+
SmallVector<Instruction *, 2> R1, R2;
1975+
getFSqrtDivOptPattern(&I, R1, R2);
1976+
if (!R1.empty() && !R2.empty() && isFSqrtDivToFMulLegal(&I, R1, R2)) {
1977+
CallInst *CI = cast<CallInst>(I.getOperand(1));
1978+
Value *SqrtOp = CI->getArgOperand(0);
1979+
if (Value *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, SqrtOp, Builder))
1980+
return replaceInstUsesWith(I, D);
1981+
}
1982+
18231983
if (isa<Constant>(Op0))
18241984
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
18251985
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)