
Commit 8f38138

AMDGPU: Refactor libcall simplify to help with future refined fast math flag usage
https://reviews.llvm.org/D156678
1 parent 94d5545 commit 8f38138

File tree

1 file changed: +101 -76 lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 101 additions & 76 deletions
@@ -51,6 +51,8 @@ class AMDGPULibCalls {
 
   const TargetMachine *TM;
 
+  bool UnsafeFPMath = false;
+
   // -fuse-native.
   bool AllNative = false;
 
@@ -73,10 +75,10 @@ class AMDGPULibCalls {
   bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // pow/powr/pown
-  bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // rootn
-  bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // fma/mad
   bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
@@ -90,10 +92,10 @@ class AMDGPULibCalls {
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
 
   // sqrt
-  bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // sin/cos
-  bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
                    AliasAnalysis *AA);
 
   // __read_pipe/__write_pipe
@@ -113,7 +115,9 @@ class AMDGPULibCalls {
 protected:
   CallInst *CI;
 
-  bool isUnsafeMath(const CallInst *CI) const;
+  bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
+  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
 
   void replaceCall(Value *With) {
     CI->replaceAllUsesWith(With);
@@ -125,6 +129,7 @@ class AMDGPULibCalls {
 
   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
 
+  void initFunction(const Function &F);
   void initNativeFuncs();
 
   // Replace a normal math function call with that native version
@@ -445,13 +450,18 @@ bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
   return AMDGPULibFunc::parse(FMangledName, FInfo);
 }
 
-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
-  if (auto Op = dyn_cast<FPMathOperator>(CI))
-    if (Op->isFast())
-      return true;
-  const Function *F = CI->getParent()->getParent();
-  Attribute Attr = F->getFnAttribute("unsafe-fp-math");
-  return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+  return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+    const FPMathOperator *FPOp) const {
+  // TODO: Refine to approxFunc or contract
+  return isUnsafeMath(FPOp);
+}
+
+void AMDGPULibCalls::initFunction(const Function &F) {
+  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
@@ -620,65 +630,61 @@ bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
   if (TDOFold(CI, FInfo))
     return true;
 
-  // Under unsafe-math, evaluate calls if possible.
-  // According to Brian Sumner, we can do this for all f32 function calls
-  // using host's double function calls.
-  if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
-    return true;
+  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+    // Under unsafe-math, evaluate calls if possible.
+    // According to Brian Sumner, we can do this for all f32 function calls
+    // using host's double function calls.
+    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+      return true;
 
-  // Copy fast flags from the original call.
-  if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
+    // Copy fast flags from the original call.
     B.setFastMathFlags(FPOp->getFastMathFlags());
 
-  // Specialized optimizations for each function call
-  switch (FInfo.getId()) {
-  case AMDGPULibFunc::EI_RECIP:
-    // skip vector function
-    assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-             FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-            "recip must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_DIVIDE:
-    // skip vector function
-    assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-             FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-            "divide must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_POW:
-  case AMDGPULibFunc::EI_POWR:
-  case AMDGPULibFunc::EI_POWN:
-    return fold_pow(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_ROOTN:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_FMA:
-  case AMDGPULibFunc::EI_MAD:
-  case AMDGPULibFunc::EI_NFMA:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_SQRT:
-    return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
-  case AMDGPULibFunc::EI_COS:
-  case AMDGPULibFunc::EI_SIN:
-    if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
-         getArgType(FInfo) == AMDGPULibFunc::F64)
-        && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
-      return fold_sincos(CI, B, FInfo, AA);
-
-    break;
-  case AMDGPULibFunc::EI_READ_PIPE_2:
-  case AMDGPULibFunc::EI_READ_PIPE_4:
-  case AMDGPULibFunc::EI_WRITE_PIPE_2:
-  case AMDGPULibFunc::EI_WRITE_PIPE_4:
-    return fold_read_write_pipe(CI, B, FInfo);
-
-  default:
-    break;
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_POW:
+    case AMDGPULibFunc::EI_POWR:
+    case AMDGPULibFunc::EI_POWN:
+      return fold_pow(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_ROOTN:
+      return fold_rootn(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_SQRT:
+      return fold_sqrt(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_COS:
+    case AMDGPULibFunc::EI_SIN:
+      return fold_sincos(FPOp, B, FInfo, AA);
+    case AMDGPULibFunc::EI_RECIP:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "recip must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
+
+    case AMDGPULibFunc::EI_DIVIDE:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "divide must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
+    case AMDGPULibFunc::EI_FMA:
+    case AMDGPULibFunc::EI_MAD:
+    case AMDGPULibFunc::EI_NFMA:
+      // skip vector function
+      return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
+    default:
+      break;
+    }
+  } else {
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_READ_PIPE_2:
+    case AMDGPULibFunc::EI_READ_PIPE_4:
+    case AMDGPULibFunc::EI_WRITE_PIPE_2:
+    case AMDGPULibFunc::EI_WRITE_PIPE_4:
+      return fold_read_write_pipe(CI, B, FInfo);
+    default:
+      break;
+    }
   }
 
   return false;
@@ -796,7 +802,7 @@ static double log2(double V) {
   }
 }
 
-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -827,7 +833,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   }
 
   // No unsafe math , no constant argument, do nothing
-  if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
     return false;
 
   // 0x1111111 means that we don't do anything for this call.
@@ -885,7 +891,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
     }
   }
 
-  if (!isUnsafeMath(CI))
+  if (!isUnsafeMath(FPOp))
     return false;
 
   // Unsafe Math optimization
@@ -1079,10 +1085,14 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   return true;
 }
 
-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  // skip vector function
+  if (getVecSize(FInfo) != 1)
+    return false;
+
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);
 
   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
   if (!CINT) {
@@ -1188,8 +1198,11 @@ FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
 }
 
 // fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
+  if (!isUnsafeMath(FPOp))
+    return false;
+
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
     if (FunctionCallee FPExpr = getNativeFunction(
@@ -1206,10 +1219,16 @@ bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
 }
 
 // fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                  const FuncInfo &fInfo, AliasAnalysis *AA) {
   assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
          fInfo.getId() == AMDGPULibFunc::EI_COS);
+
+  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+       getArgType(fInfo) != AMDGPULibFunc::F64) ||
+      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+    return false;
+
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
 
   Value *CArgVal = CI->getArgOperand(0);
@@ -1651,6 +1670,8 @@ bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
 
@@ -1675,6 +1696,7 @@ PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
                                                   FunctionAnalysisManager &AM) {
   AMDGPULibCalls Simplifier(&TM);
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   auto AA = &AM.getResult<AAManager>(F);
@@ -1701,6 +1723,8 @@ bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
   if (skipFunction(F) || UseNative.empty())
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   for (auto &BB : F) {
     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1721,6 +1745,7 @@ PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
 
   AMDGPULibCalls Simplifier;
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   for (auto &BB : F) {
