Skip to content

Commit 51b1bef

Browse files
azabaznoigcbot
authored andcommitted
Generate native sqrt for fast llvm sqrt operation and match reciprocal sqrt
1 parent 3930b96 commit 51b1bef

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,8 +2494,14 @@ bool GenXLowering::lowerSqrt(CallInst *CI) {
24942494
IGC_ASSERT_MESSAGE(GenXIntrinsic::getAnyIntrinsicID(CI) == Intrinsic::sqrt,
24952495
"llvm.sqrt expected");
24962496
auto *ResTy = CI->getType();
2497-
auto *SqrtDecl = GenXIntrinsic::getGenXDeclaration(
2498-
CI->getModule(), GenXIntrinsic::genx_ieee_sqrt, {ResTy, ResTy});
2497+
2498+
GenXIntrinsic::ID SqrtID = (CI->getType()->getScalarType()->isDoubleTy() ||
2499+
!CI->getFastMathFlags().isFast())
2500+
? GenXIntrinsic::genx_ieee_sqrt
2501+
: GenXIntrinsic::genx_sqrt;
2502+
2503+
auto *SqrtDecl =
2504+
GenXIntrinsic::getGenXDeclaration(CI->getModule(), SqrtID, {ResTy});
24992505
Value *Result = IRBuilder<>(CI).CreateCall(SqrtDecl, {CI->getArgOperand(0)},
25002506
CI->getName());
25012507
CI->replaceAllUsesWith(Result);

IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class GenXPatternMatch : public FunctionPass,
166166
// flipBoolNot : flip a (vector) bool not instruction if beneficial
167167
bool flipBoolNot(Instruction *Inst);
168168
// foldBoolAnd : fold a (vector) bool and into sel/wrregion if beneficial
169+
bool matchInverseSqrt(Instruction *I);
169170
bool foldBoolAnd(Instruction *Inst);
170171
bool simplifyPredRegion(CallInst *Inst);
171172
bool simplifyWrRegion(CallInst *Inst);
@@ -398,6 +399,9 @@ void GenXPatternMatch::visitCallInst(CallInst &I) {
398399
switch (unsigned ID = GenXIntrinsic::getGenXIntrinsicID(&I)) {
399400
default:
400401
break;
402+
case GenXIntrinsic::genx_inv:
403+
Changed |= matchInverseSqrt(&I);
404+
break;
401405
case GenXIntrinsic::genx_ssadd_sat:
402406
case GenXIntrinsic::genx_suadd_sat:
403407
case GenXIntrinsic::genx_usadd_sat:
@@ -768,6 +772,42 @@ bool GenXPatternMatch::flipBoolNot(Instruction *Inst) {
768772
return true;
769773
}
770774

775+
/// (inv (sqrt x)) -> (rsqrt x)
776+
bool GenXPatternMatch::matchInverseSqrt(Instruction *I) {
777+
IGC_ASSERT(I);
778+
779+
auto *OpInst = dyn_cast<CallInst>(I->getOperand(0));
780+
if (!OpInst)
781+
return false;
782+
783+
// Leave as it is for double types
784+
if (OpInst->getType()->getScalarType()->isDoubleTy())
785+
return false;
786+
787+
// Generate inverse sqrt only if fast flag for llvm intrinsic is used or
788+
// genx sqrt intrinsics is specified
789+
auto IID = GenXIntrinsic::getAnyIntrinsicID(OpInst);
790+
if (!(IID == GenXIntrinsic::genx_sqrt ||
791+
(IID == Intrinsic::sqrt && OpInst->getFastMathFlags().isFast())))
792+
return false;
793+
794+
// Leave as it if sqrt has multiple uses:
795+
// generating rsqrt operation is not beneficial
796+
if (OpInst->getNumUses() > 1)
797+
return false;
798+
799+
Function *Decl = GenXIntrinsic::getGenXDeclaration(
800+
I->getModule(), GenXIntrinsic::genx_rsqrt, {I->getType()});
801+
auto NewInst =
802+
CallInst::Create(Decl, {OpInst->getOperand(0)},
803+
OpInst->getName() + "inversed", I->getNextNode());
804+
I->replaceAllUsesWith(NewInst);
805+
I->eraseFromParent();
806+
807+
OpInst->eraseFromParent();
808+
return true;
809+
}
810+
771811
/***********************************************************************
772812
* foldBoolAnd : fold a (vector) bool and into sel/wrregion if beneficial
773813
*

0 commit comments

Comments
 (0)