@@ -51,6 +51,8 @@ class AMDGPULibCalls {

   const TargetMachine *TM;

+  bool UnsafeFPMath = false;
+
   // -fuse-native.
   bool AllNative = false;

@@ -73,10 +75,10 @@ class AMDGPULibCalls {
   bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);

   // pow/powr/pown
-  bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

   // rootn
-  bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

   // fma/mad
   bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
@@ -90,10 +92,10 @@ class AMDGPULibCalls {
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);

   // sqrt
-  bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);

   // sin/cos
-  bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
                    AliasAnalysis *AA);

   // __read_pipe/__write_pipe
@@ -113,7 +115,9 @@ class AMDGPULibCalls {
 protected:
   CallInst *CI;

-  bool isUnsafeMath(const CallInst *CI) const;
+  bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
+  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;

   void replaceCall(Value *With) {
     CI->replaceAllUsesWith(With);
@@ -125,6 +129,7 @@ class AMDGPULibCalls {

   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);

+  void initFunction(const Function &F);
   void initNativeFuncs();

   // Replace a normal math function call with that native version
@@ -445,13 +450,18 @@ bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
   return AMDGPULibFunc::parse(FMangledName, FInfo);
 }

-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
-  if (auto Op = dyn_cast<FPMathOperator>(CI))
-    if (Op->isFast())
-      return true;
-  const Function *F = CI->getParent()->getParent();
-  Attribute Attr = F->getFnAttribute("unsafe-fp-math");
-  return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+  return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+    const FPMathOperator *FPOp) const {
+  // TODO: Refine to approxFunc or contract
+  return isUnsafeMath(FPOp);
+}
+
+void AMDGPULibCalls::initFunction(const Function &F) {
+  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }

 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
@@ -620,65 +630,61 @@ bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
   if (TDOFold(CI, FInfo))
     return true;

-  // Under unsafe-math, evaluate calls if possible.
-  // According to Brian Sumner, we can do this for all f32 function calls
-  // using host's double function calls.
-  if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
-    return true;
+  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+    // Under unsafe-math, evaluate calls if possible.
+    // According to Brian Sumner, we can do this for all f32 function calls
+    // using host's double function calls.
+    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+      return true;

-  // Copy fast flags from the original call.
-  if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
+    // Copy fast flags from the original call.
     B.setFastMathFlags(FPOp->getFastMathFlags());

-  // Specialized optimizations for each function call
-  switch (FInfo.getId()) {
-  case AMDGPULibFunc::EI_RECIP:
-    // skip vector function
-    assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-            FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-           "recip must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_DIVIDE:
-    // skip vector function
-    assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-            FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-           "divide must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_POW:
-  case AMDGPULibFunc::EI_POWR:
-  case AMDGPULibFunc::EI_POWN:
-    return fold_pow(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_ROOTN:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_FMA:
-  case AMDGPULibFunc::EI_MAD:
-  case AMDGPULibFunc::EI_NFMA:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_SQRT:
-    return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
-  case AMDGPULibFunc::EI_COS:
-  case AMDGPULibFunc::EI_SIN:
-    if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
-         getArgType(FInfo) == AMDGPULibFunc::F64)
-        && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
-      return fold_sincos(CI, B, FInfo, AA);
-
-    break;
-  case AMDGPULibFunc::EI_READ_PIPE_2:
-  case AMDGPULibFunc::EI_READ_PIPE_4:
-  case AMDGPULibFunc::EI_WRITE_PIPE_2:
-  case AMDGPULibFunc::EI_WRITE_PIPE_4:
-    return fold_read_write_pipe(CI, B, FInfo);
-
-  default:
-    break;
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_POW:
+    case AMDGPULibFunc::EI_POWR:
+    case AMDGPULibFunc::EI_POWN:
+      return fold_pow(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_ROOTN:
+      return fold_rootn(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_SQRT:
+      return fold_sqrt(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_COS:
+    case AMDGPULibFunc::EI_SIN:
+      return fold_sincos(FPOp, B, FInfo, AA);
+    case AMDGPULibFunc::EI_RECIP:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "recip must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
+
+    case AMDGPULibFunc::EI_DIVIDE:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "divide must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
+    case AMDGPULibFunc::EI_FMA:
+    case AMDGPULibFunc::EI_MAD:
+    case AMDGPULibFunc::EI_NFMA:
+      // skip vector function
+      return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
+    default:
+      break;
+    }
+  } else {
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_READ_PIPE_2:
+    case AMDGPULibFunc::EI_READ_PIPE_4:
+    case AMDGPULibFunc::EI_WRITE_PIPE_2:
+    case AMDGPULibFunc::EI_WRITE_PIPE_4:
+      return fold_read_write_pipe(CI, B, FInfo);
+    default:
+      break;
+    }
   }

   return false;
@@ -796,7 +802,7 @@ static double log2(double V) {
 }
 }

-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -827,7 +833,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   }

   // No unsafe math, no constant argument, do nothing
-  if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
     return false;

   // 0x1111111 means that we don't do anything for this call.
@@ -885,7 +891,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
     }
   }

-  if (!isUnsafeMath(CI))
+  if (!isUnsafeMath(FPOp))
     return false;

   // Unsafe Math optimization
@@ -1079,10 +1085,14 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   return true;
 }

-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  // skip vector function
+  if (getVecSize(FInfo) != 1)
+    return false;
+
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);

   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
   if (!CINT) {
@@ -1188,8 +1198,11 @@ FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
 }

 // fold sqrt -> native_sqrt(x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
+  if (!isUnsafeMath(FPOp))
+    return false;
+
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
     if (FunctionCallee FPExpr = getNativeFunction(
@@ -1206,10 +1219,16 @@ bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
 }

 // fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                  const FuncInfo &fInfo, AliasAnalysis *AA) {
   assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
          fInfo.getId() == AMDGPULibFunc::EI_COS);
+
+  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+       getArgType(fInfo) != AMDGPULibFunc::F64) ||
+      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+    return false;
+
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;

   Value *CArgVal = CI->getArgOperand(0);
@@ -1651,6 +1670,8 @@ bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;

+  Simplifier.initFunction(F);
+
   bool Changed = false;
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();

@@ -1675,6 +1696,7 @@ PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
                                                   FunctionAnalysisManager &AM) {
   AMDGPULibCalls Simplifier(&TM);
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);

   bool Changed = false;
   auto AA = &AM.getResult<AAManager>(F);
@@ -1701,6 +1723,8 @@ bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
   if (skipFunction(F) || UseNative.empty())
     return false;

+  Simplifier.initFunction(F);
+
   bool Changed = false;
   for (auto &BB : F) {
     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1721,6 +1745,7 @@ PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,

   AMDGPULibCalls Simplifier;
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);

   bool Changed = false;
   for (auto &BB : F) {