Skip to content

Commit e468b2d

Browse files
committed
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL
The proposed patch, in general, tries to transform the below code sequence: x = 1.0 / sqrt (a); r1 = x * x; // same as 1.0 / a r2 = a * x; // same as sqrt (a) TO (If x, r1 and r2 are all used further in the code) tmp1 = 1.0 / a tmp2 = sqrt (a) tmp3 = tmp1 * tmp2 x = tmp3 r1 = tmp1 r2 = tmp2 The transform tries to make high latency sqrt and div operations independent and also saves on one multiplication. The patch was tested with SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4% No other regressions were observed. Also, no compile time differences were observed with the patch. Closes #54652
1 parent e05c1b4 commit e468b2d

File tree

2 files changed

+639
-3
lines changed

2 files changed

+639
-3
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 176 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,129 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
626626
return nullptr;
627627
}
628628

629+
static bool isFSqrtDivToFMulLegal(Instruction *X,
630+
SmallSetVector<Instruction *, 2> &R1,
631+
SmallSetVector<Instruction *, 2> &R2) {
632+
633+
BasicBlock *BBx = X->getParent();
634+
BasicBlock *BBr1 = R1[0]->getParent();
635+
BasicBlock *BBr2 = R2[0]->getParent();
636+
637+
auto IsStrictFP = [](Instruction *I) {
638+
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
639+
return II && II->isStrictFP();
640+
};
641+
642+
// Check the constaints on instruction X.
643+
auto XConstraintsSatisfied = [X, &IsStrictFP]() {
644+
if (IsStrictFP(X))
645+
return false;
646+
// X must atleast have 4 uses.
647+
// 3 uses as part of
648+
// r1 = x * x
649+
// r2 = a * x
650+
// Now, post-transform, r1/r2 will no longer have usage of 'x' and if the
651+
// changes to 'x' need to persist, we must have one more usage of 'x'
652+
if (!X->hasNUsesOrMore(4))
653+
return false;
654+
// Check if reciprocalFP is enabled.
655+
bool RecipFPMath = dyn_cast<FPMathOperator>(X)->hasAllowReciprocal();
656+
return RecipFPMath;
657+
};
658+
if (!XConstraintsSatisfied())
659+
return false;
660+
661+
// Check the constraints on instructions in R1.
662+
auto R1ConstraintsSatisfied = [BBr1, &IsStrictFP](Instruction *I) {
663+
if (IsStrictFP(I))
664+
return false;
665+
// When you have multiple instructions residing in R1 and R2 respectively,
666+
// it's difficult to generate combinations of (R1,R2) and then check if we
667+
// have the required pattern. So, for now, just be conservative.
668+
if (I->getParent() != BBr1)
669+
return false;
670+
if (!I->hasNUsesOrMore(1))
671+
return false;
672+
// The optimization tries to convert
673+
// R1 = div * div where, div = 1/sqrt(a)
674+
// to
675+
// R1 = 1/a
676+
// Now, this simplication does not work because sqrt(a)=NaN when a<0
677+
if (!I->hasNoNaNs())
678+
return false;
679+
// sqrt(-0.0) = -0.0, and doing this simplication would change the sign of
680+
// the result.
681+
return I->hasNoSignedZeros();
682+
};
683+
if (!std::all_of(R1.begin(), R1.end(), R1ConstraintsSatisfied))
684+
return false;
685+
686+
// Check the constraints on instructions in R2.
687+
auto R2ConstraintsSatisfied = [BBr2, &IsStrictFP](Instruction *I) {
688+
if (IsStrictFP(I))
689+
return false;
690+
// When you have multiple instructions residing in R1 and R2 respectively,
691+
// it's difficult to generate combination of (R1,R2) and then check if we
692+
// have the required pattern. So, for now, just be conservative.
693+
if (I->getParent() != BBr2)
694+
return false;
695+
if (!I->hasNUsesOrMore(1))
696+
return false;
697+
// This simplication changes
698+
// R2 = a * 1/sqrt(a)
699+
// to
700+
// R2 = sqrt(a)
701+
// Now, sqrt(-0.0) = -0.0 and doing this simplication would produce -0.0
702+
// instead of NaN.
703+
return I->hasNoSignedZeros();
704+
};
705+
if (!std::all_of(R2.begin(), R2.end(), R2ConstraintsSatisfied))
706+
return false;
707+
708+
// Check the constraints on X, R1 and R2 combined.
709+
// fdiv instruction and one of the multiplications must reside in the same
710+
// block. If not, the optimized code may execute more ops than before and
711+
// this may hamper the performance.
712+
return (BBx == BBr1 || BBx == BBr2);
713+
}
714+
715+
static void getFSqrtDivOptPattern(Value *Div,
716+
SmallSetVector<Instruction *, 2> &R1,
717+
SmallSetVector<Instruction *, 2> &R2) {
718+
Value *A;
719+
if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
720+
match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
721+
for (auto U : Div->users()) {
722+
Instruction *I = dyn_cast<Instruction>(U);
723+
if (!(I && I->getOpcode() == Instruction::FMul))
724+
continue;
725+
726+
if (match(I, m_FMul(m_Specific(Div), m_Specific(Div)))) {
727+
R1.insert(I);
728+
continue;
729+
}
730+
731+
Value *X;
732+
if (match(I, m_FMul(m_Specific(Div), m_Value(X))) && X == A) {
733+
R2.insert(I);
734+
continue;
735+
}
736+
737+
if (match(I, m_FMul(m_Value(X), m_Specific(Div))) && X == A) {
738+
R2.insert(I);
739+
continue;
740+
}
741+
}
742+
}
743+
}
744+
745+
static bool delayFMulSqrtTransform(Value *Div) {
746+
SmallSetVector<Instruction *, 2> R1, R2;
747+
getFSqrtDivOptPattern(Div, R1, R2);
748+
return (!(R1.empty() || R2.empty()) &&
749+
isFSqrtDivToFMulLegal((Instruction *)Div, R1, R2));
750+
}
751+
629752
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
630753
Value *Op0 = I.getOperand(0);
631754
Value *Op1 = I.getOperand(1);
@@ -705,19 +828,20 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
705828
// has the necessary (reassoc) fast-math-flags.
706829
if (I.hasNoSignedZeros() &&
707830
match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
708-
match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
831+
match(Y, m_Sqrt(m_Value(X))) && Op1 == X && !delayFMulSqrtTransform(Op0))
709832
return BinaryOperator::CreateFDivFMF(X, Y, &I);
710833
if (I.hasNoSignedZeros() &&
711834
match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
712-
match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
835+
match(Y, m_Sqrt(m_Value(X))) && Op0 == X && !delayFMulSqrtTransform(Op1))
713836
return BinaryOperator::CreateFDivFMF(X, Y, &I);
714837

715838
// Like the similar transform in instsimplify, this requires 'nsz' because
716839
// sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
717840
if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
718841
// Peek through fdiv to find squaring of square root:
719842
// (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
720-
if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
843+
if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y)))) &&
844+
!delayFMulSqrtTransform(Op0)) {
721845
Value *XX = Builder.CreateFMulFMF(X, X, &I);
722846
return BinaryOperator::CreateFDivFMF(XX, Y, &I);
723847
}
@@ -1796,6 +1920,35 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
17961920
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
17971921
}
17981922

1923+
Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
1924+
SmallSetVector<Instruction *, 2> &R1,
1925+
SmallSetVector<Instruction *, 2> &R2,
1926+
Value *SqrtOp, InstCombiner::BuilderTy &B) {
1927+
1928+
// 1. synthesize tmp1 = 1/a and replace uses of r1
1929+
B.SetInsertPoint(X);
1930+
Value *Tmp1 =
1931+
B.CreateFDivFMF(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp, R1[0]);
1932+
for (auto *I : R1)
1933+
I->replaceAllUsesWith(Tmp1);
1934+
1935+
// 2. No need of synthesizing Tmp2 again. In this scenario, tmp2 = CI. Replace
1936+
// uses of r2 with tmp2
1937+
for (auto *I : R2)
1938+
I->replaceAllUsesWith(CI);
1939+
1940+
// 3. synthesize tmp3 = tmp1 * tmp2 . Replace uses of 'x' with tmp3
1941+
Value *Tmp3;
1942+
// If x = -1/sqrt(a) initially,then Tmp3 = -(Tmp1*tmp2)
1943+
if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
1944+
Value *Mul = B.CreateFMul(Tmp1, CI);
1945+
Tmp3 = B.CreateFNegFMF(Mul, X);
1946+
} else
1947+
Tmp3 = B.CreateFMulFMF(Tmp1, CI, X);
1948+
1949+
return Tmp3;
1950+
}
1951+
17991952
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18001953
Module *M = I.getModule();
18011954

@@ -1820,6 +1973,26 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18201973
return R;
18211974

18221975
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1976+
1977+
// Convert
1978+
// x = 1.0/sqrt(a)
1979+
// r1 = x * x;
1980+
// r2 = a * x;
1981+
//
1982+
// TO
1983+
//
1984+
// r1 = 1/a
1985+
// r2 = sqrt(a)
1986+
// x = r1 * r2
1987+
SmallSetVector<Instruction *, 2> R1, R2;
1988+
getFSqrtDivOptPattern(&I, R1, R2);
1989+
if (!(R1.empty() || R2.empty()) && isFSqrtDivToFMulLegal(&I, R1, R2)) {
1990+
CallInst *CI = (CallInst *)((&I)->getOperand(1));
1991+
Value *SqrtOp = CI->getArgOperand(0);
1992+
if (Value *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, SqrtOp, Builder))
1993+
return replaceInstUsesWith(I, D);
1994+
}
1995+
18231996
if (isa<Constant>(Op0))
18241997
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
18251998
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)