Skip to content

Commit 42c6e42

Browse files
committed
AMDGPU: Handle multiple uses when matching sincos
Match how the generic implementation handles this. We now leave the other, now-dead calls behind for later passes to clean up. https://reviews.llvm.org/D156707
1 parent bbc0f99 commit 42c6e42

File tree

5 files changed

+148
-98
lines changed

5 files changed

+148
-98
lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Lines changed: 91 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,12 @@ class AMDGPULibCalls {
8282
// sqrt
8383
bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
8484

85-
bool insertSinCos(CallInst *Sin, CallInst *Cos, IRBuilder<> &B,
86-
const FuncInfo &FInfo);
85+
/// Insert a call to the sincos function \p Fsincos. Returns (value of sin, value
86+
/// of cos, sincos call).
87+
std::tuple<Value *, Value *, Value *> insertSinCos(Value *Arg,
88+
FastMathFlags FMF,
89+
IRBuilder<> &B,
90+
FunctionCallee Fsincos);
8791

8892
// sin/cos
8993
bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
@@ -1041,40 +1045,24 @@ bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
10411045
return false;
10421046
}
10431047

1044-
bool AMDGPULibCalls::insertSinCos(CallInst *Sin, CallInst *Cos, IRBuilder<> &B,
1045-
const FuncInfo &fInfo) {
1046-
Value *Arg = Sin->getOperand(0);
1047-
assert(Arg == Cos->getOperand(0));
1048-
1048+
std::tuple<Value *, Value *, Value *>
1049+
AMDGPULibCalls::insertSinCos(Value *Arg, FastMathFlags FMF, IRBuilder<> &B,
1050+
FunctionCallee Fsincos) {
1051+
DebugLoc DL = B.getCurrentDebugLocation();
10491052
Function *F = B.GetInsertBlock()->getParent();
1050-
Module *M = F->getParent();
1051-
// Merge the sin and cos.
1052-
1053-
// for OpenCL 2.0 we have only generic implementation of sincos
1054-
// function.
1055-
// FIXME: This is not true anymore
1056-
AMDGPULibFunc nf(AMDGPULibFunc::EI_SINCOS, fInfo);
1057-
nf.getLeads()[0].PtrKind =
1058-
AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);
1059-
FunctionCallee Fsincos = getFunction(M, nf);
1060-
if (!Fsincos)
1061-
return false;
1062-
10631053
B.SetInsertPointPastAllocas(F);
10641054

1065-
DILocation *MergedDebugLoc =
1066-
DILocation::getMergedLocation(Sin->getDebugLoc(), Cos->getDebugLoc());
1067-
B.SetCurrentDebugLocation(MergedDebugLoc);
1068-
1069-
AllocaInst *Alloc = B.CreateAlloca(Sin->getType(), nullptr, "__sincos_");
1055+
AllocaInst *Alloc = B.CreateAlloca(Arg->getType(), nullptr, "__sincos_");
10701056

10711057
if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {
10721058
// If the argument is an instruction, it must dominate all uses so put our
10731059
// sincos call there. Otherwise, right after the allocas works well enough
10741060
// if it's an argument or constant.
10751061

10761062
B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator());
1077-
B.SetCurrentDebugLocation(MergedDebugLoc);
1063+
1064+
// SetInsertPoint unwelcomely always tries to set the debug loc.
1065+
B.SetCurrentDebugLocation(DL);
10781066
}
10791067

10801068
Value *P = Alloc;
@@ -1085,25 +1073,12 @@ bool AMDGPULibCalls::insertSinCos(CallInst *Sin, CallInst *Cos, IRBuilder<> &B,
10851073
if (PTy->getPointerAddressSpace() != AMDGPUAS::PRIVATE_ADDRESS)
10861074
P = B.CreateAddrSpaceCast(Alloc, PTy);
10871075

1088-
// Intersect the two sets of flags.
1089-
FastMathFlags FMF = cast<FPMathOperator>(Sin)->getFastMathFlags();
1090-
FMF &= cast<FPMathOperator>(Cos)->getFastMathFlags();
1091-
B.setFastMathFlags(FMF);
1092-
1093-
CallInst *Call = CreateCallEx2(B, Fsincos, Arg, P);
1094-
LoadInst *Reload = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
1095-
Reload->setDebugLoc(Cos->getDebugLoc());
1096-
1097-
LLVM_DEBUG(errs() << "AMDIC: fold_sincos (" << *Sin << ", " << *Cos
1098-
<< ") with " << *Call << '\n');
1099-
1100-
Sin->replaceAllUsesWith(Call);
1101-
Sin->eraseFromParent();
1102-
1103-
Cos->replaceAllUsesWith(Reload);
1104-
Cos->eraseFromParent();
1076+
CallInst *SinCos = CreateCallEx2(B, Fsincos, Arg, P);
11051077

1106-
return true;
1078+
// TODO: Is it worth trying to preserve the location for the cos calls for the
1079+
// load?
1080+
LoadInst *LoadCos = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
1081+
return {SinCos, LoadCos, SinCos};
11071082
}
11081083

11091084
// fold sin, cos -> sincos.
@@ -1121,33 +1096,92 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
11211096

11221097
Value *CArgVal = FPOp->getOperand(0);
11231098
CallInst *CI = cast<CallInst>(FPOp);
1124-
bool Changed = false;
11251099

1100+
Function *F = B.GetInsertBlock()->getParent();
1101+
Module *M = F->getParent();
1102+
1103+
// Merge the sin and cos.
1104+
1105+
// for OpenCL 2.0 we have only generic implementation of sincos
1106+
// function.
1107+
// FIXME: This is not true anymore
1108+
AMDGPULibFunc SinCosLibFunc(AMDGPULibFunc::EI_SINCOS, fInfo);
1109+
SinCosLibFunc.getLeads()[0].PtrKind =
1110+
AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);
1111+
FunctionCallee FSinCos = getFunction(M, SinCosLibFunc);
1112+
if (!FSinCos)
1113+
return false;
1114+
1115+
SmallVector<CallInst *> SinCalls;
1116+
SmallVector<CallInst *> CosCalls;
1117+
SmallVector<CallInst *> SinCosCalls;
11261118
FuncInfo PartnerInfo(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN,
11271119
fInfo);
11281120
const std::string PairName = PartnerInfo.mangle();
11291121

1130-
CallInst *UI = nullptr;
1122+
StringRef SinName = isSin ? CI->getCalledFunction()->getName() : PairName;
1123+
StringRef CosName = isSin ? PairName : CI->getCalledFunction()->getName();
1124+
const std::string SinCosName = SinCosLibFunc.mangle();
1125+
1126+
// Intersect the two sets of flags.
1127+
FastMathFlags FMF = FPOp->getFastMathFlags();
1128+
MDNode *FPMath = CI->getMetadata(LLVMContext::MD_fpmath);
1129+
1130+
SmallVector<DILocation *> MergeDbgLocs = {CI->getDebugLoc()};
11311131

1132-
// TODO: Handle repeated uses, the generic implementation does.
11331132
for (User* U : CArgVal->users()) {
11341133
CallInst *XI = dyn_cast<CallInst>(U);
1135-
if (!XI || XI->isNoBuiltin())
1134+
if (!XI || XI->getFunction() != F || XI->isNoBuiltin())
11361135
continue;
11371136

11381137
Function *UCallee = XI->getCalledFunction();
1139-
if (UCallee && UCallee->getName().equals(PairName))
1140-
UI = XI;
1141-
else if (UI)
1142-
return Changed;
1138+
if (!UCallee)
1139+
continue;
1140+
1141+
bool Handled = true;
1142+
1143+
if (UCallee->getName() == SinName)
1144+
SinCalls.push_back(XI);
1145+
else if (UCallee->getName() == CosName)
1146+
CosCalls.push_back(XI);
1147+
else if (UCallee->getName() == SinCosName)
1148+
SinCosCalls.push_back(XI);
1149+
else
1150+
Handled = false;
1151+
1152+
if (Handled) {
1153+
MergeDbgLocs.push_back(XI->getDebugLoc());
1154+
auto *OtherOp = cast<FPMathOperator>(XI);
1155+
FMF &= OtherOp->getFastMathFlags();
1156+
FPMath = MDNode::getMostGenericFPMath(
1157+
FPMath, XI->getMetadata(LLVMContext::MD_fpmath));
1158+
}
11431159
}
11441160

1145-
if (!UI)
1146-
return Changed;
1161+
if (SinCalls.empty() || CosCalls.empty())
1162+
return false;
1163+
1164+
B.setFastMathFlags(FMF);
1165+
B.setDefaultFPMathTag(FPMath);
1166+
DILocation *DbgLoc = DILocation::getMergedLocations(MergeDbgLocs);
1167+
B.SetCurrentDebugLocation(DbgLoc);
1168+
1169+
auto [Sin, Cos, SinCos] = insertSinCos(CArgVal, FMF, B, FSinCos);
1170+
1171+
auto replaceTrigInsts = [](ArrayRef<CallInst *> Calls, Value *Res) {
1172+
for (CallInst *C : Calls)
1173+
C->replaceAllUsesWith(Res);
1174+
1175+
// Leave the other dead instructions to avoid clobbering iterators.
1176+
};
11471177

1148-
CallInst *Sin = isSin ? CI : UI;
1149-
CallInst *Cos = isSin ? UI : CI;
1150-
return insertSinCos(Sin, Cos, B, fInfo) || Changed;
1178+
replaceTrigInsts(SinCalls, Sin);
1179+
replaceTrigInsts(CosCalls, Cos);
1180+
replaceTrigInsts(SinCosCalls, SinCos);
1181+
1182+
// It's safe to delete the original now.
1183+
CI->eraseFromParent();
1184+
return true;
11511185
}
11521186

11531187
bool AMDGPULibCalls::evaluateScalarMathFunc(const FuncInfo &FInfo,

llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sincos.defined.ll

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ define void @sincos_f32(float %x, ptr addrspace(1) nocapture writeonly %sin_out,
110110
; CHECK-NEXT: [[TMP1:%.*]] = call contract float @_Z6sincosfPU3AS0f(float [[X]], ptr [[TMP0]])
111111
; CHECK-NEXT: [[TMP2:%.*]] = load float, ptr addrspace(5) [[__SINCOS_]], align 4
112112
; CHECK-NEXT: store float [[TMP1]], ptr addrspace(1) [[SIN_OUT]], align 4
113+
; CHECK-NEXT: [[CALL1:%.*]] = tail call contract float @_Z3cosf(float [[X]])
113114
; CHECK-NEXT: store float [[TMP2]], ptr addrspace(1) [[COS_OUT]], align 4
114115
; CHECK-NEXT: ret void
115116
;
@@ -130,6 +131,7 @@ define void @sincos_f32_value_is_same_constantfp(ptr addrspace(1) nocapture writ
130131
; CHECK-NEXT: [[TMP1:%.*]] = call contract float @_Z6sincosfPU3AS0f(float 4.200000e+01, ptr [[TMP0]])
131132
; CHECK-NEXT: [[TMP2:%.*]] = load float, ptr addrspace(5) [[__SINCOS_]], align 4
132133
; CHECK-NEXT: store float [[TMP1]], ptr addrspace(1) [[SIN_OUT]], align 4
134+
; CHECK-NEXT: [[CALL1:%.*]] = tail call contract float @_Z3cosf(float 4.200000e+01)
133135
; CHECK-NEXT: store float [[TMP2]], ptr addrspace(1) [[COS_OUT]], align 4
134136
; CHECK-NEXT: ret void
135137
;
@@ -159,6 +161,7 @@ define void @sincos_v2f32(<2 x float> %x, ptr addrspace(1) nocapture writeonly %
159161
; CHECK-NEXT: [[TMP1:%.*]] = call contract <2 x float> @_Z6sincosDv2_fPU3AS0S_(<2 x float> [[X]], ptr [[TMP0]])
160162
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x float>, ptr addrspace(5) [[__SINCOS_]], align 8
161163
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr addrspace(1) [[SIN_OUT]], align 8
164+
; CHECK-NEXT: [[CALL1:%.*]] = tail call contract <2 x float> @_Z3cosDv2_f(<2 x float> [[X]])
162165
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr addrspace(1) [[COS_OUT]], align 8
163166
; CHECK-NEXT: ret void
164167
;

llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sincos.defined.nobuiltin.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ define void @sincos_f32(float %x, ptr addrspace(1) nocapture writeonly %sin_out,
6060
; CHECK-NEXT: [[TMP0:%.*]] = call contract float @_Z6sincosfPU3AS0f(float [[X]], ptr [[__SINCOS_]])
6161
; CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[__SINCOS_]], align 4
6262
; CHECK-NEXT: store float [[TMP0]], ptr addrspace(1) [[SIN_OUT]], align 4
63+
; CHECK-NEXT: [[CALL1:%.*]] = tail call contract float @_Z3cosf(float [[X]])
6364
; CHECK-NEXT: store float [[TMP1]], ptr addrspace(1) [[COS_OUT]], align 4
6465
; CHECK-NEXT: ret void
6566
;

0 commit comments

Comments
 (0)