Skip to content

Commit 1901f44

Browse files
authored
[InstCombine] Handle more even/odd math functions (#81324)
At the moment, this PR adds support only for the `erf` function. Fixes #77220.
1 parent 00c0638 commit 1901f44

File tree

8 files changed

+135
-46
lines changed

8 files changed

+135
-46
lines changed

llvm/include/llvm/Analysis/TargetLibraryInfo.def

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,21 @@ TLI_DEFINE_ENUM_INTERNAL(ctermid)
10691069
TLI_DEFINE_STRING_INTERNAL("ctermid")
10701070
TLI_DEFINE_SIG_INTERNAL(Ptr, Ptr)
10711071

1072+
/// double erf(double x);
1073+
TLI_DEFINE_ENUM_INTERNAL(erf)
1074+
TLI_DEFINE_STRING_INTERNAL("erf")
1075+
TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl)
1076+
1077+
/// float erff(float x);
1078+
TLI_DEFINE_ENUM_INTERNAL(erff)
1079+
TLI_DEFINE_STRING_INTERNAL("erff")
1080+
TLI_DEFINE_SIG_INTERNAL(Flt, Flt)
1081+
1082+
/// long double erfl(long double x);
1083+
TLI_DEFINE_ENUM_INTERNAL(erfl)
1084+
TLI_DEFINE_STRING_INTERNAL("erfl")
1085+
TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)
1086+
10721087
/// int execl(const char *path, const char *arg, ...);
10731088
TLI_DEFINE_ENUM_INTERNAL(execl)
10741089
TLI_DEFINE_STRING_INTERNAL("execl")

llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class LibCallSimplifier {
204204
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
205205
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
206206
Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
207+
Value *optimizeSymmetric(CallInst *CI, LibFunc Func, IRBuilderBase &B);
207208
// Wrapper for all floating point library call optimizations
208209
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
209210
IRBuilderBase &B);

llvm/lib/Analysis/TargetLibraryInfo.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,9 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T,
808808
TLI.setUnavailable(LibFunc_cabs);
809809
TLI.setUnavailable(LibFunc_cabsf);
810810
TLI.setUnavailable(LibFunc_cabsl);
811+
TLI.setUnavailable(LibFunc_erf);
812+
TLI.setUnavailable(LibFunc_erff);
813+
TLI.setUnavailable(LibFunc_erfl);
811814
TLI.setUnavailable(LibFunc_ffs);
812815
TLI.setUnavailable(LibFunc_flockfile);
813816
TLI.setUnavailable(LibFunc_fseeko);

llvm/lib/Transforms/Utils/BuildLibCalls.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,9 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F,
11371137
case LibFunc_cosl:
11381138
case LibFunc_cospi:
11391139
case LibFunc_cospif:
1140+
case LibFunc_erf:
1141+
case LibFunc_erff:
1142+
case LibFunc_erfl:
11401143
case LibFunc_exp:
11411144
case LibFunc_expf:
11421145
case LibFunc_expl:

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,49 +1908,6 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) {
19081908
*CI, B.CreateCall(FSqrt, B.CreateFAdd(RealReal, ImagImag), "cabs"));
19091909
}
19101910

1911-
/// Simplify sin/tan/cos library calls whose argument is a sign-manipulated
/// value.
///
///   sin(-X) --> -sin(X)   (only when the fneg has a single use)
///   tan(-X) --> -tan(X)   (only when the fneg has a single use)
///   cos(-X), cos(fabs(X)), cos(copysign(X, Y)) --> cos(X)
///
/// \param Call the library call being simplified.
/// \param Func the library function \p Call was recognized as.
/// \param B    builder used to emit the replacement instructions.
/// \returns the replacement value, or nullptr when no fold applies.
static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func,
                                      IRBuilderBase &B) {
  if (!isa<FPMathOperator>(Call))
    return nullptr;

  // Everything emitted below inherits the fast-math flags of the original
  // call; the guard restores the builder's previous flags on exit.
  IRBuilderBase::FastMathFlagGuard Guard(B);
  B.setFastMathFlags(Call->getFastMathFlags());

  // TODO: Can this be shared to also handle LLVM intrinsics?
  Value *Arg = Call->getArgOperand(0);
  Value *X;
  switch (Func) {
  case LibFunc_sin:
  case LibFunc_sinf:
  case LibFunc_sinl:
  case LibFunc_tan:
  case LibFunc_tanf:
  case LibFunc_tanl: {
    // Odd functions: hoist the negation out of the call.
    if (!match(Arg, m_OneUse(m_FNeg(m_Value(X)))))
      return nullptr;
    Value *Reflected =
        copyFlags(*Call, B.CreateCall(Call->getCalledFunction(), X));
    return B.CreateFNeg(Reflected);
  }
  case LibFunc_cos:
  case LibFunc_cosf:
  case LibFunc_cosl: {
    // Even function: the sign of the argument does not matter.
    Value *Sign;
    const bool SignIsIrrelevant =
        match(Arg, m_FNeg(m_Value(X))) || match(Arg, m_FAbs(m_Value(X))) ||
        match(Arg, m_CopySign(m_Value(X), m_Value(Sign)));
    if (!SignIsIrrelevant)
      return nullptr;
    return copyFlags(*Call, B.CreateCall(Call->getCalledFunction(), X, "cos"));
  }
  default:
    return nullptr;
  }
}
1953-
19541911
// Return a properly extended integer (DstWidth bits wide) if the operation is
19551912
// an itofp.
19561913
static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) {
@@ -2797,6 +2754,63 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
27972754
return true;
27982755
}
27992756

2757+
/// Fold a call to an even or odd unary math function whose argument is a
/// sign-manipulated value.
///
///   Even: f(-x), f(fabs(x)), f(copysign(x, y)) --> f(x)
///   Odd:  f(-x) --> -f(x)   (only when the fneg has a single use)
///
/// \param CI     the library call being simplified.
/// \param IsEven true if the callee satisfies f(-x) == f(x); false if it
///               satisfies f(-x) == -f(x).
/// \param B      builder used to emit the replacement instructions.
/// \returns the replacement value, or nullptr when no fold applies.
static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven,
                                    IRBuilderBase &B) {
  // Fast-math flags are read from CI below, which is only meaningful for
  // floating-point math operators.
  if (!isa<FPMathOperator>(CI))
    return nullptr;

  Value *Src = CI->getArgOperand(0);
  Value *X = nullptr;

  bool Foldable;
  if (IsEven) {
    // Dropping the sign manipulation adds no instruction, so the fold is
    // worthwhile even when the fneg/fabs/copysign has other users.
    Foldable = match(Src, m_FNeg(m_Value(X))) ||
               match(Src, m_FAbs(m_Value(X))) ||
               match(Src, m_CopySign(m_Value(X), m_Value()));
  } else {
    // The odd fold emits a new fneg on the result; require the original fneg
    // to have a single use so the negation is moved rather than duplicated.
    Foldable = match(Src, m_OneUse(m_FNeg(m_Value(X))));
  }
  if (!Foldable)
    return nullptr;

  // Both the new call and the optional fneg inherit CI's fast-math flags; the
  // guard restores the builder's previous flags on exit.
  IRBuilderBase::FastMathFlagGuard Guard(B);
  B.setFastMathFlags(CI->getFastMathFlags());

  Value *NewCall = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
  if (IsEven)
    return NewCall;
  return B.CreateFNeg(NewCall);
}
2787+
2788+
/// Classify \p Func as an even or odd math function and forward the call to
/// optimizeSymmetricCall with the matching parity.
///
/// cos/cosf/cosl are treated as even; sin, tan and erf (with their f/l
/// variants) are treated as odd. Any other function is left untouched.
///
/// \returns the simplified value, or nullptr if \p Func is not handled or no
///          fold applies.
Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, LibFunc Func,
                                            IRBuilderBase &B) {
  bool IsEven;
  switch (Func) {
  default:
    // Not a function with known parity.
    return nullptr;

  // Even functions.
  case LibFunc_cos:
  case LibFunc_cosf:
  case LibFunc_cosl:
    IsEven = true;
    break;

  // Odd functions.
  case LibFunc_sin:
  case LibFunc_sinf:
  case LibFunc_sinl:
  case LibFunc_tan:
  case LibFunc_tanf:
  case LibFunc_tanl:
  case LibFunc_erf:
  case LibFunc_erff:
  case LibFunc_erfl:
    IsEven = false;
    break;
  }

  return optimizeSymmetricCall(CI, IsEven, B);
}
2813+
28002814
Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) {
28012815
// Make sure the prototype is as expected, otherwise the rest of the
28022816
// function is probably invalid and likely to abort.
@@ -3678,7 +3692,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
36783692
if (CI->isStrictFP())
36793693
return nullptr;
36803694

3681-
if (Value *V = optimizeTrigReflections(CI, Func, Builder))
3695+
if (Value *V = optimizeSymmetric(CI, Func, Builder))
36823696
return V;
36833697

36843698
switch (Func) {
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
4+
declare double @erf(double)
5+
declare double @cos(double)
6+
declare double @fabs(double)
7+
8+
declare void @use(double) nounwind
9+
10+
; Check odd parity: -erf(-x) == erf(x)
11+
define double @test_erf(double %x) {
12+
; CHECK-LABEL: define double @test_erf(
13+
; CHECK-SAME: double [[X:%.*]]) {
14+
; CHECK-NEXT: [[RES:%.*]] = tail call reassoc double @erf(double [[X]])
15+
; CHECK-NEXT: ret double [[RES]]
16+
;
17+
%neg_x = fneg double %x
18+
%res = tail call reassoc double @erf(double %neg_x)
19+
%neg_res = fneg double %res
20+
ret double %neg_res
21+
}
22+
23+
; Check even parity: cos(fabs(x)) == cos(x)
24+
define double @test_cos_fabs(double %x) {
25+
; CHECK-LABEL: define double @test_cos_fabs(
26+
; CHECK-SAME: double [[X:%.*]]) {
27+
; CHECK-NEXT: [[RES:%.*]] = tail call reassoc double @cos(double [[X]])
28+
; CHECK-NEXT: ret double [[RES]]
29+
;
30+
%fabs_res = call double @fabs(double %x)
31+
%res = tail call reassoc double @cos(double %fabs_res)
32+
ret double %res
33+
}
34+
35+
; Do nothing when the negated argument has more than one use
36+
define double @test_erf_multi_use(double %x) {
37+
; CHECK-LABEL: define double @test_erf_multi_use(
38+
; CHECK-SAME: double [[X:%.*]]) {
39+
; CHECK-NEXT: [[NEG_X:%.*]] = fneg double [[X]]
40+
; CHECK-NEXT: call void @use(double [[NEG_X]])
41+
; CHECK-NEXT: [[RES:%.*]] = call double @erf(double [[NEG_X]])
42+
; CHECK-NEXT: [[NEG_RES:%.*]] = fneg double [[RES]]
43+
; CHECK-NEXT: ret double [[NEG_RES]]
44+
;
45+
%neg_x = fneg double %x
46+
call void @use(double %neg_x)
47+
%res = call double @erf(double %neg_x)
48+
%neg_res = fneg double %res
49+
ret double %neg_res
50+
}

llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@
4747
## the exact count first; the two directives should add up to that.
4848
## Yes, this means additions to TLI will fail this test, but the argument
4949
## to -COUNT can't be an expression.
50-
# AVAIL: TLI knows 476 symbols, 243 available
50+
# AVAIL: TLI knows 479 symbols, 243 available
5151
# AVAIL-COUNT-243: {{^}} available
5252
# AVAIL-NOT: {{^}} available
53-
# UNAVAIL-COUNT-233: not available
53+
# UNAVAIL-COUNT-236: not available
5454
# UNAVAIL-NOT: not available
5555

5656
## This is a large file so it's worth telling lit to stop here.

llvm/unittests/Analysis/TargetLibraryInfoTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ TEST_F(TargetLibraryInfoTest, ValidProto) {
264264
"declare double @pow(double, double)\n"
265265
"declare float @powf(float, float)\n"
266266
"declare x86_fp80 @powl(x86_fp80, x86_fp80)\n"
267+
"declare double @erf(double)\n"
268+
"declare float @erff(float)\n"
269+
"declare x86_fp80 @erfl(x86_fp80)\n"
267270
"declare i32 @printf(i8*, ...)\n"
268271
"declare i32 @putc(i32, %struct*)\n"
269272
"declare i32 @putc_unlocked(i32, %struct*)\n"

0 commit comments

Comments
 (0)