Skip to content

[SimplifyLibCalls] Merge sqrt into the power of exp #79146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class LibCallSimplifier {
Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Expand Down
67 changes: 67 additions & 0 deletions llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
return Ret;
}

// sqrt(exp(X)) -> exp(X * 0.5)
Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
if (!CI->hasAllowReassoc())
return nullptr;

Function *SqrtFn = CI->getCalledFunction();
CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
if (!Arg || !Arg->hasAllowReassoc() || !Arg->hasOneUse())
return nullptr;
Intrinsic::ID ArgID = Arg->getIntrinsicID();
LibFunc ArgLb = NotLibFunc;
TLI->getLibFunc(*Arg, ArgLb);

LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;

if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
switch (SqrtLb) {
case LibFunc_sqrtf:
ExpLb = LibFunc_expf;
Exp2Lb = LibFunc_exp2f;
Exp10Lb = LibFunc_exp10f;
break;
case LibFunc_sqrt:
ExpLb = LibFunc_exp;
Exp2Lb = LibFunc_exp2;
Exp10Lb = LibFunc_exp10;
break;
case LibFunc_sqrtl:
ExpLb = LibFunc_expl;
Exp2Lb = LibFunc_exp2l;
Exp10Lb = LibFunc_exp10l;
break;
default:
return nullptr;
}
else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
Comment on lines +2556 to +2576
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not have a better way of handling intrinsic-or-libcall?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find any other approaches to this. All code in this file seems to be written in the same way.

if (CI->getType()->getScalarType()->isFloatTy()) {
ExpLb = LibFunc_expf;
Exp2Lb = LibFunc_exp2f;
Exp10Lb = LibFunc_exp10f;
} else if (CI->getType()->getScalarType()->isDoubleTy()) {
ExpLb = LibFunc_exp;
Exp2Lb = LibFunc_exp2;
Exp10Lb = LibFunc_exp10;
} else
return nullptr;
} else
return nullptr;

if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2)
return nullptr;

IRBuilderBase::InsertPointGuard Guard(B);
B.SetInsertPoint(Arg);
auto *ExpOperand = Arg->getOperand(0);
auto *FMul =
B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
CI, "merged.sqrt");

Arg->setOperand(0, FMul);
return Arg;
}

Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
Expand All @@ -2550,6 +2614,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
Callee->getIntrinsicID() == Intrinsic::sqrt))
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);

if (Value *Opt = mergeSqrtToExp(CI, B))
return Opt;

if (!CI->isFast())
return Ret;

Expand Down
120 changes: 120 additions & 0 deletions llvm/test/Transforms/InstCombine/sqrt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,127 @@ define float @sqrt_call_fabs_f32(float %x) {
ret float %sqrt
}

define double @sqrt_exp(double %x) {
; CHECK-LABEL: @sqrt_exp(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%e = call reassoc double @llvm.exp.f64(double %x)
%res = call reassoc double @llvm.sqrt.f64(double %e)
ret double %res
}

define double @sqrt_exp_2(double %x) {
; CHECK-LABEL: @sqrt_exp_2(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%e = call reassoc double @exp(double %x)
%res = call reassoc double @sqrt(double %e)
ret double %res
}

define double @sqrt_exp2(double %x) {
; CHECK-LABEL: @sqrt_exp2(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp2(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%e = call reassoc double @exp2(double %x)
%res = call reassoc double @sqrt(double %e)
ret double %res
}

define double @sqrt_exp10(double %x) {
; CHECK-LABEL: @sqrt_exp10(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp10(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%e = call reassoc double @exp10(double %x)
%res = call reassoc double @sqrt(double %e)
ret double %res
}

; Negative test
define double @sqrt_exp_nofast_1(double %x) {
; CHECK-LABEL: @sqrt_exp_nofast_1(
; CHECK-NEXT: [[E:%.*]] = call double @llvm.exp.f64(double [[X:%.*]])
; CHECK-NEXT: [[RES:%.*]] = call reassoc double @llvm.sqrt.f64(double [[E]])
; CHECK-NEXT: ret double [[RES]]
;
%e = call double @llvm.exp.f64(double %x)
%res = call reassoc double @llvm.sqrt.f64(double %e)
ret double %res
}

; Negative test
define double @sqrt_exp_nofast_2(double %x) {
; CHECK-LABEL: @sqrt_exp_nofast_2(
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[X:%.*]])
; CHECK-NEXT: [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
; CHECK-NEXT: ret double [[RES]]
;
%e = call reassoc double @llvm.exp.f64(double %x)
%res = call double @llvm.sqrt.f64(double %e)
ret double %res
}

define double @sqrt_exp_merge_constant(double %x, double %y) {
; CHECK-LABEL: @sqrt_exp_merge_constant(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc nsz double [[X:%.*]], 5.000000e+00
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%mul = fmul reassoc nsz double %x, 10.0
%e = call reassoc double @llvm.exp.f64(double %mul)
%res = call reassoc nsz double @llvm.sqrt.f64(double %e)
ret double %res
}

define double @sqrt_exp_intr_and_libcall(double %x) {
; CHECK-LABEL: @sqrt_exp_intr_and_libcall(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%e = call reassoc double @exp(double %x)
%res = call reassoc double @llvm.sqrt.f64(double %e)
ret double %res
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should reduce test flags. Also, can you add the tests with libcall exp + intrinsic sqrt and intrinsic exp + libcall sqrt? We shouldn't introduce new libcalls from intrinsics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've slightly simplified tests and added libcall + intrinsic tests. Fast-flags weren't modified. I'll adjust them when we decide on the correct set of flags that controls transformation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed fast flag to reassoc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the use of reassoc for this purpose, but it is consistent with existing practice, so I can't object to it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jcranmer-intel, in your oppinion, what is the correct set of flags? Since you said that it is consistent with the existing practice, I won't change them in this PR. But we may start a discussion at discourse and systematically change all places.

define double @sqrt_exp_intr_and_libcall_2(double %x) {
; CHECK-LABEL: @sqrt_exp_intr_and_libcall_2(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
; CHECK-NEXT: ret double [[E]]
;
%e = call reassoc double @llvm.exp.f64(double %x)
%res = call reassoc double @sqrt(double %e)
ret double %res
}

define <2 x float> @sqrt_exp_vec(<2 x float> %x) {
; CHECK-LABEL: @sqrt_exp_vec(
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc <2 x float> [[X:%.*]], <float 5.000000e-01, float 5.000000e-01>
; CHECK-NEXT: [[E:%.*]] = call reassoc <2 x float> @llvm.exp.v2f32(<2 x float> [[MERGED_SQRT]])
; CHECK-NEXT: ret <2 x float> [[E]]
;
%e = call reassoc <2 x float> @llvm.exp.v2f32(<2 x float> %x)
%res = call reassoc <2 x float> @llvm.sqrt.v2f32(<2 x float> %e)
ret <2 x float> %res
}

declare i32 @foo(double)
declare double @sqrt(double) readnone
declare float @sqrtf(float)
declare float @llvm.fabs.f32(float)
declare double @llvm.exp.f64(double)
declare double @llvm.sqrt.f64(double)
declare double @exp(double)
declare double @exp2(double)
declare double @exp10(double)
declare <2 x float> @llvm.exp.v2f32(<2 x float>)
declare <2 x float> @llvm.sqrt.v2f32(<2 x float>)