Skip to content

Commit 1e5d6cf

Browse files
author
Hal Finkel
committed
Apply the InstCombine fptrunc sqrt optimization to llvm.sqrt
InstCombine, in visitFPTrunc, applies the following optimization to sqrt calls: (fptrunc (sqrt (fpext x))) -> (sqrtf x), but does not apply the same optimization to llvm.sqrt. This is a problem because, to enable vectorization, Clang generates llvm.sqrt instead of sqrt in fast-math mode; since this optimization is applied to sqrt but not to llvm.sqrt, the fast-math code is sometimes slower. This change makes InstCombine apply this optimization to llvm.sqrt as well. This fixes the specific problem in PR17758, although the same underlying issue (optimizations applied to libcalls are not applied to intrinsics) exists for other optimizations in SimplifyLibCalls. llvm-svn: 194935
1 parent 39162ac commit 1e5d6cf

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,9 +1262,14 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
12621262
}
12631263

12641264
// Fold (fptrunc (sqrt (fpext x))) -> (sqrtf x)
1265+
// Note that we restrict this transformation based on
1266+
// TLI->has(LibFunc::sqrtf), even for the sqrt intrinsic, because
1267+
// TLI->has(LibFunc::sqrtf) is sufficient to guarantee that the
1268+
// single-precision intrinsic can be expanded in the backend.
12651269
CallInst *Call = dyn_cast<CallInst>(CI.getOperand(0));
12661270
if (Call && Call->getCalledFunction() && TLI->has(LibFunc::sqrtf) &&
1267-
Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) &&
1271+
(Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) ||
1272+
Call->getCalledFunction()->getIntrinsicID() == Intrinsic::sqrt) &&
12681273
Call->getNumArgOperands() == 1 &&
12691274
Call->hasOneUse()) {
12701275
CastInst *Arg = dyn_cast<CastInst>(Call->getArgOperand(0));
@@ -1275,11 +1280,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
12751280
Arg->getOperand(0)->getType()->isFloatTy()) {
12761281
Function *Callee = Call->getCalledFunction();
12771282
Module *M = CI.getParent()->getParent()->getParent();
1278-
Constant *SqrtfFunc = M->getOrInsertFunction("sqrtf",
1279-
Callee->getAttributes(),
1280-
Builder->getFloatTy(),
1281-
Builder->getFloatTy(),
1282-
NULL);
1283+
Constant *SqrtfFunc = (Callee->getIntrinsicID() == Intrinsic::sqrt) ?
1284+
Intrinsic::getDeclaration(M, Intrinsic::sqrt, Builder->getFloatTy()) :
1285+
M->getOrInsertFunction("sqrtf", Callee->getAttributes(),
1286+
Builder->getFloatTy(), Builder->getFloatTy(),
1287+
NULL);
12831288
CallInst *ret = CallInst::Create(SqrtfFunc, Arg->getOperand(0),
12841289
"sqrtfcall");
12851290
ret->setAttributes(Callee->getAttributes());

llvm/test/Transforms/InstCombine/double-float-shrink-1.ll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ define double @sin_test2(float %f) nounwind readnone {
263263
ret double %call
264264
; CHECK: call double @sin(double %conv)
265265
}
266+
266267
define float @sqrt_test(float %f) nounwind readnone {
267268
; CHECK: sqrt_test
268269
%conv = fpext float %f to double
@@ -272,6 +273,15 @@ define float @sqrt_test(float %f) nounwind readnone {
272273
; CHECK: call float @sqrtf(float %f)
273274
}
274275

276+
define float @sqrt_int_test(float %f) nounwind readnone {
277+
; CHECK: sqrt_int_test
278+
%conv = fpext float %f to double
279+
%call = call double @llvm.sqrt.f64(double %conv)
280+
%conv1 = fptrunc double %call to float
281+
ret float %conv1
282+
; CHECK: call float @llvm.sqrt.f32(float %f)
283+
}
284+
275285
define double @sqrt_test2(float %f) nounwind readnone {
276286
; CHECK: sqrt_test2
277287
%conv = fpext float %f to double
@@ -331,3 +341,6 @@ declare double @acos(double) nounwind readnone
331341
declare double @acosh(double) nounwind readnone
332342
declare double @asin(double) nounwind readnone
333343
declare double @asinh(double) nounwind readnone
344+
345+
declare double @llvm.sqrt.f64(double) nounwind readnone
346+

0 commit comments

Comments
 (0)