13
13
14
14
#include " InstCombineInternal.h"
15
15
#include " llvm/ADT/APInt.h"
16
+ #include " llvm/ADT/SmallPtrSet.h"
16
17
#include " llvm/ADT/SmallVector.h"
17
18
#include " llvm/Analysis/InstructionSimplify.h"
18
19
#include " llvm/Analysis/ValueTracking.h"
@@ -666,6 +667,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
666
667
return nullptr ;
667
668
}
668
669
670
+ // If we have the following pattern,
671
+ // X = 1.0/sqrt(a)
672
+ // R1 = X * X
673
+ // R2 = a/sqrt(a)
674
+ // then this method collects all the instructions that match R1 and R2.
675
+ static bool getFSqrtDivOptPattern (Instruction *Div,
676
+ SmallPtrSetImpl<Instruction *> &R1,
677
+ SmallPtrSetImpl<Instruction *> &R2) {
678
+ Value *A;
679
+ if (match (Div, m_FDiv (m_FPOne (), m_Sqrt (m_Value (A)))) ||
680
+ match (Div, m_FDiv (m_SpecificFP (-1.0 ), m_Sqrt (m_Value (A))))) {
681
+ for (User *U : Div->users ()) {
682
+ Instruction *I = cast<Instruction>(U);
683
+ if (match (I, m_FMul (m_Specific (Div), m_Specific (Div))))
684
+ R1.insert (I);
685
+ }
686
+
687
+ CallInst *CI = cast<CallInst>(Div->getOperand (1 ));
688
+ for (User *U : CI->users ()) {
689
+ Instruction *I = cast<Instruction>(U);
690
+ if (match (I, m_FDiv (m_Specific (A), m_Sqrt (m_Specific (A)))))
691
+ R2.insert (I);
692
+ }
693
+ }
694
+ return !R1.empty () && !R2.empty ();
695
+ }
696
+
697
+ // Check legality for transforming
698
+ // x = 1.0/sqrt(a)
699
+ // r1 = x * x;
700
+ // r2 = a/sqrt(a);
701
+ //
702
+ // TO
703
+ //
704
+ // r1 = 1/a
705
+ // r2 = sqrt(a)
706
+ // x = r1 * r2
707
+ // This transform works only when 'a' is known positive.
708
+ static bool isFSqrtDivToFMulLegal (Instruction *X,
709
+ SmallPtrSetImpl<Instruction *> &R1,
710
+ SmallPtrSetImpl<Instruction *> &R2) {
711
+ // Check if the required pattern for the transformation exists.
712
+ if (!getFSqrtDivOptPattern (X, R1, R2))
713
+ return false ;
714
+
715
+ BasicBlock *BBx = X->getParent ();
716
+ BasicBlock *BBr1 = (*R1.begin ())->getParent ();
717
+ BasicBlock *BBr2 = (*R2.begin ())->getParent ();
718
+
719
+ CallInst *FSqrt = cast<CallInst>(X->getOperand (1 ));
720
+ if (!FSqrt->hasAllowReassoc () || !FSqrt->hasNoNaNs () ||
721
+ !FSqrt->hasNoSignedZeros () || !FSqrt->hasNoInfs ())
722
+ return false ;
723
+
724
+ // We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
725
+ // by recip fp as it is strictly meant to transform ops of type a/b to
726
+ // a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
727
+ // has been used(rather abused)in the past for algebraic rewrites.
728
+ if (!X->hasAllowReassoc () || !X->hasAllowReciprocal () || !X->hasNoInfs ())
729
+ return false ;
730
+
731
+ // Check the constraints on X, R1 and R2 combined.
732
+ // fdiv instruction and one of the multiplications must reside in the same
733
+ // block. If not, the optimized code may execute more ops than before and
734
+ // this may hamper the performance.
735
+ if (BBx != BBr1 && BBx != BBr2)
736
+ return false ;
737
+
738
+ // Check the constraints on instructions in R1.
739
+ if (any_of (R1, [BBr1](Instruction *I) {
740
+ // When you have multiple instructions residing in R1 and R2
741
+ // respectively, it's difficult to generate combinations of (R1,R2) and
742
+ // then check if we have the required pattern. So, for now, just be
743
+ // conservative.
744
+ return (I->getParent () != BBr1 || !I->hasAllowReassoc ());
745
+ }))
746
+ return false ;
747
+
748
+ // Check the constraints on instructions in R2.
749
+ return all_of (R2, [BBr2](Instruction *I) {
750
+ // When you have multiple instructions residing in R1 and R2
751
+ // respectively, it's difficult to generate combination of (R1,R2) and
752
+ // then check if we have the required pattern. So, for now, just be
753
+ // conservative.
754
+ return (I->getParent () == BBr2 && I->hasAllowReassoc ());
755
+ });
756
+ }
757
+
669
758
Instruction *InstCombinerImpl::foldFMulReassoc (BinaryOperator &I) {
670
759
Value *Op0 = I.getOperand (0 );
671
760
Value *Op1 = I.getOperand (1 );
@@ -1917,6 +2006,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
1917
2006
return BinaryOperator::CreateFMulFMF (Op0, NewSqrt, &I);
1918
2007
}
1919
2008
2009
+ // Change
2010
+ // X = 1/sqrt(a)
2011
+ // R1 = X * X
2012
+ // R2 = a * X
2013
+ //
2014
+ // TO
2015
+ //
2016
+ // FDiv = 1/a
2017
+ // FSqrt = sqrt(a)
2018
+ // FMul = FDiv * FSqrt
2019
+ // Replace Uses Of R1 With FDiv
2020
+ // Replace Uses Of R2 With FSqrt
2021
+ // Replace Uses Of X With FMul
2022
+ static Instruction *
2023
+ convertFSqrtDivIntoFMul (CallInst *CI, Instruction *X,
2024
+ const SmallPtrSetImpl<Instruction *> &R1,
2025
+ const SmallPtrSetImpl<Instruction *> &R2,
2026
+ InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
2027
+
2028
+ B.SetInsertPoint (X);
2029
+
2030
+ // Have an instruction that is representative of all of instructions in R1 and
2031
+ // get the most common fpmath metadata and fast-math flags on it.
2032
+ Value *SqrtOp = CI->getArgOperand (0 );
2033
+ auto *FDiv = cast<Instruction>(
2034
+ B.CreateFDiv (ConstantFP::get (X->getType (), 1.0 ), SqrtOp));
2035
+ auto *R1FPMathMDNode = (*R1.begin ())->getMetadata (LLVMContext::MD_fpmath);
2036
+ FastMathFlags R1FMF = (*R1.begin ())->getFastMathFlags (); // Common FMF
2037
+ for (Instruction *I : R1) {
2038
+ R1FPMathMDNode = MDNode::getMostGenericFPMath (
2039
+ R1FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2040
+ R1FMF &= I->getFastMathFlags ();
2041
+ IC->replaceInstUsesWith (*I, FDiv);
2042
+ IC->eraseInstFromFunction (*I);
2043
+ }
2044
+ FDiv->setMetadata (LLVMContext::MD_fpmath, R1FPMathMDNode);
2045
+ FDiv->copyFastMathFlags (R1FMF);
2046
+
2047
+ // Have a single sqrt call instruction that is representative of all of
2048
+ // instructions in R2 and get the most common fpmath metadata and fast-math
2049
+ // flags on it.
2050
+ auto *FSqrt = cast<CallInst>(CI->clone ());
2051
+ FSqrt->insertBefore (CI);
2052
+ auto *R2FPMathMDNode = (*R2.begin ())->getMetadata (LLVMContext::MD_fpmath);
2053
+ FastMathFlags R2FMF = (*R2.begin ())->getFastMathFlags (); // Common FMF
2054
+ for (Instruction *I : R2) {
2055
+ R2FPMathMDNode = MDNode::getMostGenericFPMath (
2056
+ R2FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2057
+ R2FMF &= I->getFastMathFlags ();
2058
+ IC->replaceInstUsesWith (*I, FSqrt);
2059
+ IC->eraseInstFromFunction (*I);
2060
+ }
2061
+ FSqrt->setMetadata (LLVMContext::MD_fpmath, R2FPMathMDNode);
2062
+ FSqrt->copyFastMathFlags (R2FMF);
2063
+
2064
+ Instruction *FMul;
2065
+ // If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
2066
+ if (match (X, m_FDiv (m_SpecificFP (-1.0 ), m_Specific (CI)))) {
2067
+ Value *Mul = B.CreateFMul (FDiv, FSqrt);
2068
+ FMul = cast<Instruction>(B.CreateFNeg (Mul));
2069
+ } else
2070
+ FMul = cast<Instruction>(B.CreateFMul (FDiv, FSqrt));
2071
+ FMul->copyMetadata (*X);
2072
+ FMul->copyFastMathFlags (FastMathFlags::intersectRewrite (R1FMF, R2FMF) |
2073
+ FastMathFlags::unionValue (R1FMF, R2FMF));
2074
+ IC->replaceInstUsesWith (*X, FMul);
2075
+ return IC->eraseInstFromFunction (*X);
2076
+ }
2077
+
1920
2078
Instruction *InstCombinerImpl::visitFDiv (BinaryOperator &I) {
1921
2079
Module *M = I.getModule ();
1922
2080
@@ -1941,6 +2099,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
1941
2099
return R;
1942
2100
1943
2101
Value *Op0 = I.getOperand (0 ), *Op1 = I.getOperand (1 );
2102
+
2103
+ // Convert
2104
+ // x = 1.0/sqrt(a)
2105
+ // r1 = x * x;
2106
+ // r2 = a/sqrt(a);
2107
+ //
2108
+ // TO
2109
+ //
2110
+ // r1 = 1/a
2111
+ // r2 = sqrt(a)
2112
+ // x = r1 * r2
2113
+ SmallPtrSet<Instruction *, 2 > R1, R2;
2114
+ if (isFSqrtDivToFMulLegal (&I, R1, R2)) {
2115
+ CallInst *CI = cast<CallInst>(I.getOperand (1 ));
2116
+ if (Instruction *D = convertFSqrtDivIntoFMul (CI, &I, R1, R2, Builder, this ))
2117
+ return D;
2118
+ }
2119
+
1944
2120
if (isa<Constant>(Op0))
1945
2121
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1946
2122
if (Instruction *R = FoldOpIntoSelect (I, SI))
0 commit comments