@@ -113,15 +113,17 @@ class AMDGPULibCalls {
113
113
FunctionCallee getNativeFunction (Module *M, const FuncInfo &FInfo);
114
114
115
115
protected:
116
- CallInst *CI;
117
-
118
116
bool isUnsafeMath (const FPMathOperator *FPOp) const ;
119
117
120
118
bool canIncreasePrecisionOfConstantFold (const FPMathOperator *FPOp) const ;
121
119
122
- void replaceCall (Value *With) {
123
- CI->replaceAllUsesWith (With);
124
- CI->eraseFromParent ();
120
+ static void replaceCall (Instruction *I, Value *With) {
121
+ I->replaceAllUsesWith (With);
122
+ I->eraseFromParent ();
123
+ }
124
+
125
+ static void replaceCall (FPMathOperator *I, Value *With) {
126
+ replaceCall (cast<Instruction>(I), With);
125
127
}
126
128
127
129
public:
@@ -501,15 +503,14 @@ bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
501
503
DEBUG_WITH_TYPE (" usenative" , dbgs () << " <useNative> replace " << *aCI
502
504
<< " with native version of sin/cos" );
503
505
504
- replaceCall (sinval);
506
+ replaceCall (aCI, sinval);
505
507
return true ;
506
508
}
507
509
}
508
510
return false ;
509
511
}
510
512
511
513
bool AMDGPULibCalls::useNative (CallInst *aCI) {
512
- CI = aCI;
513
514
Function *Callee = aCI->getCalledFunction ();
514
515
if (!Callee || aCI->isNoBuiltin ())
515
516
return false ;
@@ -601,7 +602,6 @@ bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
601
602
602
603
// This function returns false if no change was made; it returns true otherwise.
603
604
bool AMDGPULibCalls::fold (CallInst *CI, AliasAnalysis *AA) {
604
- this ->CI = CI;
605
605
Function *Callee = CI->getCalledFunction ();
606
606
// Ignore indirect calls.
607
607
if (!Callee || CI->isNoBuiltin ())
@@ -733,7 +733,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
733
733
nval = ConstantDataVector::get (context, tmp);
734
734
}
735
735
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *nval << " \n " );
736
- replaceCall (nval);
736
+ replaceCall (CI, nval);
737
737
return true ;
738
738
}
739
739
} else {
@@ -743,7 +743,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
743
743
if (CF->isExactlyValue (tr[i].input )) {
744
744
Value *nval = ConstantFP::get (CF->getType (), tr[i].result );
745
745
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *nval << " \n " );
746
- replaceCall (nval);
746
+ replaceCall (CI, nval);
747
747
return true ;
748
748
}
749
749
}
@@ -765,7 +765,7 @@ bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
765
765
opr0,
766
766
" recip2div" );
767
767
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *nval << " \n " );
768
- replaceCall (nval);
768
+ replaceCall (CI, nval);
769
769
return true ;
770
770
}
771
771
return false ;
@@ -786,7 +786,7 @@ bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
786
786
Value *nval1 = B.CreateFDiv (ConstantFP::get (opr1->getType (), 1.0 ),
787
787
opr1, " __div2recip" );
788
788
Value *nval = B.CreateFMul (opr0, nval1, " __div2mul" );
789
- replaceCall (nval);
789
+ replaceCall (CI, nval);
790
790
return true ;
791
791
}
792
792
return false ;
@@ -813,8 +813,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
813
813
ConstantFP *CF;
814
814
ConstantInt *CINT;
815
815
Type *eltType;
816
- Value *opr0 = CI-> getArgOperand (0 );
817
- Value *opr1 = CI-> getArgOperand (1 );
816
+ Value *opr0 = FPOp-> getOperand (0 );
817
+ Value *opr1 = FPOp-> getOperand (1 );
818
818
ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);
819
819
820
820
if (getVecSize (FInfo) == 1 ) {
@@ -841,37 +841,37 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
841
841
842
842
if ((CF && CF->isZero ()) || (CINT && ci_opr1 == 0 ) || CZero) {
843
843
// pow/powr/pown(x, 0) == 1
844
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> 1\n " );
844
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> 1\n " );
845
845
Constant *cnval = ConstantFP::get (eltType, 1.0 );
846
846
if (getVecSize (FInfo) > 1 ) {
847
847
cnval = ConstantDataVector::getSplat (getVecSize (FInfo), cnval);
848
848
}
849
- replaceCall (cnval);
849
+ replaceCall (FPOp, cnval);
850
850
return true ;
851
851
}
852
852
if ((CF && CF->isExactlyValue (1.0 )) || (CINT && ci_opr1 == 1 )) {
853
853
// pow/powr/pown(x, 1.0) = x
854
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr0 << " \n " );
855
- replaceCall (opr0);
854
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> " << *opr0 << " \n " );
855
+ replaceCall (FPOp, opr0);
856
856
return true ;
857
857
}
858
858
if ((CF && CF->isExactlyValue (2.0 )) || (CINT && ci_opr1 == 2 )) {
859
859
// pow/powr/pown(x, 2.0) = x*x
860
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr0 << " * " << *opr0
861
- << " \n " );
860
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
861
+ << *opr0 << " \n " );
862
862
Value *nval = B.CreateFMul (opr0, opr0, " __pow2" );
863
- replaceCall (nval);
863
+ replaceCall (FPOp, nval);
864
864
return true ;
865
865
}
866
866
if ((CF && CF->isExactlyValue (-1.0 )) || (CINT && ci_opr1 == -1 )) {
867
867
// pow/powr/pown(x, -1.0) = 1.0/x
868
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> 1 / " << *opr0 << " \n " );
868
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << " \n " );
869
869
Constant *cnval = ConstantFP::get (eltType, 1.0 );
870
870
if (getVecSize (FInfo) > 1 ) {
871
871
cnval = ConstantDataVector::getSplat (getVecSize (FInfo), cnval);
872
872
}
873
873
Value *nval = B.CreateFDiv (cnval, opr0, " __powrecip" );
874
- replaceCall (nval);
874
+ replaceCall (FPOp, nval);
875
875
return true ;
876
876
}
877
877
@@ -882,11 +882,11 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
882
882
getFunction (M, AMDGPULibFunc (issqrt ? AMDGPULibFunc::EI_SQRT
883
883
: AMDGPULibFunc::EI_RSQRT,
884
884
FInfo))) {
885
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << FInfo.getName ()
885
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> " << FInfo.getName ()
886
886
<< ' (' << *opr0 << " )\n " );
887
887
Value *nval = CreateCallEx (B,FPExpr, opr0, issqrt ? " __pow2sqrt"
888
888
: " __pow2rsqrt" );
889
- replaceCall (nval);
889
+ replaceCall (FPOp, nval);
890
890
return true ;
891
891
}
892
892
}
@@ -939,10 +939,10 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
939
939
}
940
940
nval = B.CreateFDiv (cnval, nval, " __1powprod" );
941
941
}
942
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> "
942
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> "
943
943
<< ((ci_opr1 < 0 ) ? " 1/prod(" : " prod(" ) << *opr0
944
944
<< " )\n " );
945
- replaceCall (nval);
945
+ replaceCall (FPOp, nval);
946
946
return true ;
947
947
}
948
948
@@ -1066,7 +1066,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
1066
1066
if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
1067
1067
nTy = FixedVectorType::get (nTyS, vTy);
1068
1068
unsigned size = nTy->getScalarSizeInBits ();
1069
- opr_n = CI-> getArgOperand (1 );
1069
+ opr_n = FPOp-> getOperand (1 );
1070
1070
if (opr_n->getType ()->isIntegerTy ())
1071
1071
opr_n = B.CreateZExtOrBitCast (opr_n, nTy, " __ytou" );
1072
1072
else
@@ -1078,9 +1078,9 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
1078
1078
nval = B.CreateBitCast (nval, opr0->getType ());
1079
1079
}
1080
1080
1081
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> "
1081
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> "
1082
1082
<< " exp2(" << *opr1 << " * log2(" << *opr0 << " ))\n " );
1083
- replaceCall (nval);
1083
+ replaceCall (FPOp, nval);
1084
1084
1085
1085
return true ;
1086
1086
}
@@ -1100,43 +1100,44 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
1100
1100
}
1101
1101
int ci_opr1 = (int )CINT->getSExtValue ();
1102
1102
if (ci_opr1 == 1 ) { // rootn(x, 1) = x
1103
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr0 << " \n " );
1104
- replaceCall (opr0);
1103
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> " << *opr0 << " \n " );
1104
+ replaceCall (FPOp, opr0);
1105
1105
return true ;
1106
1106
}
1107
- if (ci_opr1 == 2 ) { // rootn(x, 2) = sqrt(x)
1108
- Module *M = CI->getModule ();
1107
+
1108
+ Module *M = B.GetInsertBlock ()->getModule ();
1109
+ if (ci_opr1 == 2 ) { // rootn(x, 2) = sqrt(x)
1109
1110
if (FunctionCallee FPExpr =
1110
1111
getFunction (M, AMDGPULibFunc (AMDGPULibFunc::EI_SQRT, FInfo))) {
1111
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> sqrt(" << *opr0 << " )\n " );
1112
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
1113
+ << " )\n " );
1112
1114
Value *nval = CreateCallEx (B,FPExpr, opr0, " __rootn2sqrt" );
1113
- replaceCall (nval);
1115
+ replaceCall (FPOp, nval);
1114
1116
return true ;
1115
1117
}
1116
1118
} else if (ci_opr1 == 3 ) { // rootn(x, 3) = cbrt(x)
1117
- Module *M = CI->getModule ();
1118
1119
if (FunctionCallee FPExpr =
1119
1120
getFunction (M, AMDGPULibFunc (AMDGPULibFunc::EI_CBRT, FInfo))) {
1120
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> cbrt(" << *opr0 << " )\n " );
1121
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
1122
+ << " )\n " );
1121
1123
Value *nval = CreateCallEx (B,FPExpr, opr0, " __rootn2cbrt" );
1122
- replaceCall (nval);
1124
+ replaceCall (FPOp, nval);
1123
1125
return true ;
1124
1126
}
1125
1127
} else if (ci_opr1 == -1 ) { // rootn(x, -1) = 1.0/x
1126
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> 1.0 / " << *opr0 << " \n " );
1128
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << " \n " );
1127
1129
Value *nval = B.CreateFDiv (ConstantFP::get (opr0->getType (), 1.0 ),
1128
1130
opr0,
1129
1131
" __rootn2div" );
1130
- replaceCall (nval);
1132
+ replaceCall (FPOp, nval);
1131
1133
return true ;
1132
- } else if (ci_opr1 == -2 ) { // rootn(x, -2) = rsqrt(x)
1133
- Module *M = CI->getModule ();
1134
+ } else if (ci_opr1 == -2 ) { // rootn(x, -2) = rsqrt(x)
1134
1135
if (FunctionCallee FPExpr =
1135
1136
getFunction (M, AMDGPULibFunc (AMDGPULibFunc::EI_RSQRT, FInfo))) {
1136
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> rsqrt(" << *opr0
1137
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
1137
1138
<< " )\n " );
1138
1139
Value *nval = CreateCallEx (B,FPExpr, opr0, " __rootn2rsqrt" );
1139
- replaceCall (nval);
1140
+ replaceCall (FPOp, nval);
1140
1141
return true ;
1141
1142
}
1142
1143
}
@@ -1154,23 +1155,23 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
1154
1155
if ((CF0 && CF0->isZero ()) || (CF1 && CF1->isZero ())) {
1155
1156
// fma/mad(a, b, c) = c if a=0 || b=0
1156
1157
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr2 << " \n " );
1157
- replaceCall (opr2);
1158
+ replaceCall (CI, opr2);
1158
1159
return true ;
1159
1160
}
1160
1161
if (CF0 && CF0->isExactlyValue (1 .0f )) {
1161
1162
// fma/mad(a, b, c) = b+c if a=1
1162
1163
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr1 << " + " << *opr2
1163
1164
<< " \n " );
1164
1165
Value *nval = B.CreateFAdd (opr1, opr2, " fmaadd" );
1165
- replaceCall (nval);
1166
+ replaceCall (CI, nval);
1166
1167
return true ;
1167
1168
}
1168
1169
if (CF1 && CF1->isExactlyValue (1 .0f )) {
1169
1170
// fma/mad(a, b, c) = a+c if b=1
1170
1171
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr0 << " + " << *opr2
1171
1172
<< " \n " );
1172
1173
Value *nval = B.CreateFAdd (opr0, opr2, " fmaadd" );
1173
- replaceCall (nval);
1174
+ replaceCall (CI, nval);
1174
1175
return true ;
1175
1176
}
1176
1177
if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
@@ -1179,7 +1180,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
1179
1180
LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> " << *opr0 << " * "
1180
1181
<< *opr1 << " \n " );
1181
1182
Value *nval = B.CreateFMul (opr0, opr1, " fmamul" );
1182
- replaceCall (nval);
1183
+ replaceCall (CI, nval);
1183
1184
return true ;
1184
1185
}
1185
1186
}
@@ -1205,13 +1206,15 @@ bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
1205
1206
1206
1207
if (getArgType (FInfo) == AMDGPULibFunc::F32 && (getVecSize (FInfo) == 1 ) &&
1207
1208
(FInfo.getPrefix () != AMDGPULibFunc::NATIVE)) {
1209
+ Module *M = B.GetInsertBlock ()->getModule ();
1210
+
1208
1211
if (FunctionCallee FPExpr = getNativeFunction (
1209
- CI-> getModule () , AMDGPULibFunc (AMDGPULibFunc::EI_SQRT, FInfo))) {
1210
- Value *opr0 = CI-> getArgOperand (0 );
1211
- LLVM_DEBUG (errs () << " AMDIC: " << *CI << " ---> "
1212
+ M , AMDGPULibFunc (AMDGPULibFunc::EI_SQRT, FInfo))) {
1213
+ Value *opr0 = FPOp-> getOperand (0 );
1214
+ LLVM_DEBUG (errs () << " AMDIC: " << *FPOp << " ---> "
1212
1215
<< " sqrt(" << *opr0 << " )\n " );
1213
1216
Value *nval = CreateCallEx (B,FPExpr, opr0, " __sqrt" );
1214
- replaceCall (nval);
1217
+ replaceCall (FPOp, nval);
1215
1218
return true ;
1216
1219
}
1217
1220
}
@@ -1231,7 +1234,8 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
1231
1234
1232
1235
bool const isSin = fInfo .getId () == AMDGPULibFunc::EI_SIN;
1233
1236
1234
- Value *CArgVal = CI->getArgOperand (0 );
1237
+ Value *CArgVal = FPOp->getOperand (0 );
1238
+ CallInst *CI = cast<CallInst>(FPOp);
1235
1239
BasicBlock * const CBB = CI->getParent ();
1236
1240
1237
1241
int const MaxScan = 30 ;
@@ -1247,7 +1251,7 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
1247
1251
CArgVal->replaceAllUsesWith (AvailableVal);
1248
1252
if (CArgVal->getNumUses () == 0 )
1249
1253
LI->eraseFromParent ();
1250
- CArgVal = CI-> getArgOperand (0 );
1254
+ CArgVal = FPOp-> getOperand (0 );
1251
1255
}
1252
1256
}
1253
1257
}
@@ -1617,12 +1621,12 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
1617
1621
}
1618
1622
}
1619
1623
1620
- LLVMContext &context = CI-> getParent ()-> getParent () ->getContext ();
1624
+ LLVMContext &context = aCI ->getContext ();
1621
1625
Constant *nval0, *nval1;
1622
1626
if (FuncVecSize == 1 ) {
1623
- nval0 = ConstantFP::get (CI ->getType (), DVal0[0 ]);
1627
+ nval0 = ConstantFP::get (aCI ->getType (), DVal0[0 ]);
1624
1628
if (hasTwoResults)
1625
- nval1 = ConstantFP::get (CI ->getType (), DVal1[0 ]);
1629
+ nval1 = ConstantFP::get (aCI ->getType (), DVal1[0 ]);
1626
1630
} else {
1627
1631
if (getArgType (FInfo) == AMDGPULibFunc::F32) {
1628
1632
SmallVector <float , 0 > FVal0, FVal1;
@@ -1653,7 +1657,7 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
1653
1657
new StoreInst (nval1, aCI->getArgOperand (1 ), aCI);
1654
1658
}
1655
1659
1656
- replaceCall (nval0);
1660
+ replaceCall (aCI, nval0);
1657
1661
return true ;
1658
1662
}
1659
1663
0 commit comments