Skip to content

Commit c2c22c6

Browse files
committed
AMDGPU: Don't store current instruction in AMDGPULibCalls member
This was adding confusing global state which was shadowed most of the time. https://reviews.llvm.org/D156680
1 parent d517117 commit c2c22c6

File tree

1 file changed

+63
-59
lines changed

1 file changed

+63
-59
lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 63 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -113,15 +113,17 @@ class AMDGPULibCalls {
113113
FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);
114114

115115
protected:
116-
CallInst *CI;
117-
118116
bool isUnsafeMath(const FPMathOperator *FPOp) const;
119117

120118
bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
121119

122-
void replaceCall(Value *With) {
123-
CI->replaceAllUsesWith(With);
124-
CI->eraseFromParent();
120+
static void replaceCall(Instruction *I, Value *With) {
121+
I->replaceAllUsesWith(With);
122+
I->eraseFromParent();
123+
}
124+
125+
static void replaceCall(FPMathOperator *I, Value *With) {
126+
replaceCall(cast<Instruction>(I), With);
125127
}
126128

127129
public:
@@ -501,15 +503,14 @@ bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
501503
DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
502504
<< " with native version of sin/cos");
503505

504-
replaceCall(sinval);
506+
replaceCall(aCI, sinval);
505507
return true;
506508
}
507509
}
508510
return false;
509511
}
510512

511513
bool AMDGPULibCalls::useNative(CallInst *aCI) {
512-
CI = aCI;
513514
Function *Callee = aCI->getCalledFunction();
514515
if (!Callee || aCI->isNoBuiltin())
515516
return false;
@@ -601,7 +602,6 @@ bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
601602

602603
// This function returns false if no change; return true otherwise.
603604
bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
604-
this->CI = CI;
605605
Function *Callee = CI->getCalledFunction();
606606
// Ignore indirect calls.
607607
if (!Callee || CI->isNoBuiltin())
@@ -733,7 +733,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
733733
nval = ConstantDataVector::get(context, tmp);
734734
}
735735
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
736-
replaceCall(nval);
736+
replaceCall(CI, nval);
737737
return true;
738738
}
739739
} else {
@@ -743,7 +743,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
743743
if (CF->isExactlyValue(tr[i].input)) {
744744
Value *nval = ConstantFP::get(CF->getType(), tr[i].result);
745745
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
746-
replaceCall(nval);
746+
replaceCall(CI, nval);
747747
return true;
748748
}
749749
}
@@ -765,7 +765,7 @@ bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
765765
opr0,
766766
"recip2div");
767767
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
768-
replaceCall(nval);
768+
replaceCall(CI, nval);
769769
return true;
770770
}
771771
return false;
@@ -786,7 +786,7 @@ bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
786786
Value *nval1 = B.CreateFDiv(ConstantFP::get(opr1->getType(), 1.0),
787787
opr1, "__div2recip");
788788
Value *nval = B.CreateFMul(opr0, nval1, "__div2mul");
789-
replaceCall(nval);
789+
replaceCall(CI, nval);
790790
return true;
791791
}
792792
return false;
@@ -813,8 +813,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
813813
ConstantFP *CF;
814814
ConstantInt *CINT;
815815
Type *eltType;
816-
Value *opr0 = CI->getArgOperand(0);
817-
Value *opr1 = CI->getArgOperand(1);
816+
Value *opr0 = FPOp->getOperand(0);
817+
Value *opr1 = FPOp->getOperand(1);
818818
ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);
819819

820820
if (getVecSize(FInfo) == 1) {
@@ -841,37 +841,37 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
841841

842842
if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
843843
// pow/powr/pown(x, 0) == 1
844-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1\n");
844+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n");
845845
Constant *cnval = ConstantFP::get(eltType, 1.0);
846846
if (getVecSize(FInfo) > 1) {
847847
cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
848848
}
849-
replaceCall(cnval);
849+
replaceCall(FPOp, cnval);
850850
return true;
851851
}
852852
if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
853853
// pow/powr/pown(x, 1.0) = x
854-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
855-
replaceCall(opr0);
854+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
855+
replaceCall(FPOp, opr0);
856856
return true;
857857
}
858858
if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
859859
// pow/powr/pown(x, 2.0) = x*x
860-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * " << *opr0
861-
<< "\n");
860+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
861+
<< *opr0 << "\n");
862862
Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
863-
replaceCall(nval);
863+
replaceCall(FPOp, nval);
864864
return true;
865865
}
866866
if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
867867
// pow/powr/pown(x, -1.0) = 1.0/x
868-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1 / " << *opr0 << "\n");
868+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n");
869869
Constant *cnval = ConstantFP::get(eltType, 1.0);
870870
if (getVecSize(FInfo) > 1) {
871871
cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
872872
}
873873
Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
874-
replaceCall(nval);
874+
replaceCall(FPOp, nval);
875875
return true;
876876
}
877877

@@ -882,11 +882,11 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
882882
getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
883883
: AMDGPULibFunc::EI_RSQRT,
884884
FInfo))) {
885-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << FInfo.getName()
885+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
886886
<< '(' << *opr0 << ")\n");
887887
Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
888888
: "__pow2rsqrt");
889-
replaceCall(nval);
889+
replaceCall(FPOp, nval);
890890
return true;
891891
}
892892
}
@@ -939,10 +939,10 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
939939
}
940940
nval = B.CreateFDiv(cnval, nval, "__1powprod");
941941
}
942-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
942+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
943943
<< ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
944944
<< ")\n");
945-
replaceCall(nval);
945+
replaceCall(FPOp, nval);
946946
return true;
947947
}
948948

@@ -1066,7 +1066,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
10661066
if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
10671067
nTy = FixedVectorType::get(nTyS, vTy);
10681068
unsigned size = nTy->getScalarSizeInBits();
1069-
opr_n = CI->getArgOperand(1);
1069+
opr_n = FPOp->getOperand(1);
10701070
if (opr_n->getType()->isIntegerTy())
10711071
opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
10721072
else
@@ -1078,9 +1078,9 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
10781078
nval = B.CreateBitCast(nval, opr0->getType());
10791079
}
10801080

1081-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1081+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
10821082
<< "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
1083-
replaceCall(nval);
1083+
replaceCall(FPOp, nval);
10841084

10851085
return true;
10861086
}
@@ -1100,43 +1100,44 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
11001100
}
11011101
int ci_opr1 = (int)CINT->getSExtValue();
11021102
if (ci_opr1 == 1) { // rootn(x, 1) = x
1103-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
1104-
replaceCall(opr0);
1103+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
1104+
replaceCall(FPOp, opr0);
11051105
return true;
11061106
}
1107-
if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
1108-
Module *M = CI->getModule();
1107+
1108+
Module *M = B.GetInsertBlock()->getModule();
1109+
if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
11091110
if (FunctionCallee FPExpr =
11101111
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1111-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> sqrt(" << *opr0 << ")\n");
1112+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
1113+
<< ")\n");
11121114
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
1113-
replaceCall(nval);
1115+
replaceCall(FPOp, nval);
11141116
return true;
11151117
}
11161118
} else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
1117-
Module *M = CI->getModule();
11181119
if (FunctionCallee FPExpr =
11191120
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
1120-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> cbrt(" << *opr0 << ")\n");
1121+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
1122+
<< ")\n");
11211123
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
1122-
replaceCall(nval);
1124+
replaceCall(FPOp, nval);
11231125
return true;
11241126
}
11251127
} else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
1126-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1.0 / " << *opr0 << "\n");
1128+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n");
11271129
Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
11281130
opr0,
11291131
"__rootn2div");
1130-
replaceCall(nval);
1132+
replaceCall(FPOp, nval);
11311133
return true;
1132-
} else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
1133-
Module *M = CI->getModule();
1134+
} else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
11341135
if (FunctionCallee FPExpr =
11351136
getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
1136-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> rsqrt(" << *opr0
1137+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
11371138
<< ")\n");
11381139
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
1139-
replaceCall(nval);
1140+
replaceCall(FPOp, nval);
11401141
return true;
11411142
}
11421143
}
@@ -1154,23 +1155,23 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
11541155
if ((CF0 && CF0->isZero()) || (CF1 && CF1->isZero())) {
11551156
// fma/mad(a, b, c) = c if a=0 || b=0
11561157
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr2 << "\n");
1157-
replaceCall(opr2);
1158+
replaceCall(CI, opr2);
11581159
return true;
11591160
}
11601161
if (CF0 && CF0->isExactlyValue(1.0f)) {
11611162
// fma/mad(a, b, c) = b+c if a=1
11621163
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr1 << " + " << *opr2
11631164
<< "\n");
11641165
Value *nval = B.CreateFAdd(opr1, opr2, "fmaadd");
1165-
replaceCall(nval);
1166+
replaceCall(CI, nval);
11661167
return true;
11671168
}
11681169
if (CF1 && CF1->isExactlyValue(1.0f)) {
11691170
// fma/mad(a, b, c) = a+c if b=1
11701171
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " + " << *opr2
11711172
<< "\n");
11721173
Value *nval = B.CreateFAdd(opr0, opr2, "fmaadd");
1173-
replaceCall(nval);
1174+
replaceCall(CI, nval);
11741175
return true;
11751176
}
11761177
if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
@@ -1179,7 +1180,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
11791180
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * "
11801181
<< *opr1 << "\n");
11811182
Value *nval = B.CreateFMul(opr0, opr1, "fmamul");
1182-
replaceCall(nval);
1183+
replaceCall(CI, nval);
11831184
return true;
11841185
}
11851186
}
@@ -1205,13 +1206,15 @@ bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
12051206

12061207
if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
12071208
(FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
1209+
Module *M = B.GetInsertBlock()->getModule();
1210+
12081211
if (FunctionCallee FPExpr = getNativeFunction(
1209-
CI->getModule(), AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1210-
Value *opr0 = CI->getArgOperand(0);
1211-
LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1212+
M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1213+
Value *opr0 = FPOp->getOperand(0);
1214+
LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
12121215
<< "sqrt(" << *opr0 << ")\n");
12131216
Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
1214-
replaceCall(nval);
1217+
replaceCall(FPOp, nval);
12151218
return true;
12161219
}
12171220
}
@@ -1231,7 +1234,8 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
12311234

12321235
bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
12331236

1234-
Value *CArgVal = CI->getArgOperand(0);
1237+
Value *CArgVal = FPOp->getOperand(0);
1238+
CallInst *CI = cast<CallInst>(FPOp);
12351239
BasicBlock * const CBB = CI->getParent();
12361240

12371241
int const MaxScan = 30;
@@ -1247,7 +1251,7 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
12471251
CArgVal->replaceAllUsesWith(AvailableVal);
12481252
if (CArgVal->getNumUses() == 0)
12491253
LI->eraseFromParent();
1250-
CArgVal = CI->getArgOperand(0);
1254+
CArgVal = FPOp->getOperand(0);
12511255
}
12521256
}
12531257
}
@@ -1617,12 +1621,12 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
16171621
}
16181622
}
16191623

1620-
LLVMContext &context = CI->getParent()->getParent()->getContext();
1624+
LLVMContext &context = aCI->getContext();
16211625
Constant *nval0, *nval1;
16221626
if (FuncVecSize == 1) {
1623-
nval0 = ConstantFP::get(CI->getType(), DVal0[0]);
1627+
nval0 = ConstantFP::get(aCI->getType(), DVal0[0]);
16241628
if (hasTwoResults)
1625-
nval1 = ConstantFP::get(CI->getType(), DVal1[0]);
1629+
nval1 = ConstantFP::get(aCI->getType(), DVal1[0]);
16261630
} else {
16271631
if (getArgType(FInfo) == AMDGPULibFunc::F32) {
16281632
SmallVector <float, 0> FVal0, FVal1;
@@ -1653,7 +1657,7 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
16531657
new StoreInst(nval1, aCI->getArgOperand(1), aCI);
16541658
}
16551659

1656-
replaceCall(nval0);
1660+
replaceCall(aCI, nval0);
16571661
return true;
16581662
}
16591663

0 commit comments

Comments
 (0)