@@ -166,6 +166,7 @@ class GenXPatternMatch : public FunctionPass,
166
166
// flipBoolNot : flip a (vector) bool not instruction if beneficial
167
167
bool flipBoolNot (Instruction *Inst);
168
168
// foldBoolAnd : fold a (vector) bool and into sel/wrregion if beneficial
169
+ bool matchInverseSqrt (Instruction *I);
169
170
bool foldBoolAnd (Instruction *Inst);
170
171
bool simplifyPredRegion (CallInst *Inst);
171
172
bool simplifyWrRegion (CallInst *Inst);
@@ -398,6 +399,9 @@ void GenXPatternMatch::visitCallInst(CallInst &I) {
398
399
switch (unsigned ID = GenXIntrinsic::getGenXIntrinsicID (&I)) {
399
400
default :
400
401
break ;
402
+ case GenXIntrinsic::genx_inv:
403
+ Changed |= matchInverseSqrt (&I);
404
+ break ;
401
405
case GenXIntrinsic::genx_ssadd_sat:
402
406
case GenXIntrinsic::genx_suadd_sat:
403
407
case GenXIntrinsic::genx_usadd_sat:
@@ -768,6 +772,42 @@ bool GenXPatternMatch::flipBoolNot(Instruction *Inst) {
768
772
return true ;
769
773
}
770
774
775
+ // / (inv (sqrt x)) -> (rsqrt x)
776
+ bool GenXPatternMatch::matchInverseSqrt (Instruction *I) {
777
+ IGC_ASSERT (I);
778
+
779
+ auto *OpInst = dyn_cast<CallInst>(I->getOperand (0 ));
780
+ if (!OpInst)
781
+ return false ;
782
+
783
+ // Leave as it is for double types
784
+ if (OpInst->getType ()->getScalarType ()->isDoubleTy ())
785
+ return false ;
786
+
787
+ // Generate inverse sqrt only if fast flag for llvm intrinsic is used or
788
+ // genx sqrt intrinsics is specified
789
+ auto IID = GenXIntrinsic::getAnyIntrinsicID (OpInst);
790
+ if (!(IID == GenXIntrinsic::genx_sqrt ||
791
+ (IID == Intrinsic::sqrt && OpInst->getFastMathFlags ().isFast ())))
792
+ return false ;
793
+
794
+ // Leave as it if sqrt has multiple uses:
795
+ // generating rsqrt operation is not beneficial
796
+ if (OpInst->getNumUses () > 1 )
797
+ return false ;
798
+
799
+ Function *Decl = GenXIntrinsic::getGenXDeclaration (
800
+ I->getModule (), GenXIntrinsic::genx_rsqrt, {I->getType ()});
801
+ auto NewInst =
802
+ CallInst::Create (Decl, {OpInst->getOperand (0 )},
803
+ OpInst->getName () + " inversed" , I->getNextNode ());
804
+ I->replaceAllUsesWith (NewInst);
805
+ I->eraseFromParent ();
806
+
807
+ OpInst->eraseFromParent ();
808
+ return true ;
809
+ }
810
+
771
811
/* **********************************************************************
772
812
* foldBoolAnd : fold a (vector) bool and into sel/wrregion if beneficial
773
813
*
0 commit comments