Skip to content

[InstCombine] Handle more even/odd math functions #81324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llvm/include/llvm/Analysis/TargetLibraryInfo.def
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,21 @@ TLI_DEFINE_ENUM_INTERNAL(ctermid)
TLI_DEFINE_STRING_INTERNAL("ctermid")
TLI_DEFINE_SIG_INTERNAL(Ptr, Ptr)

/// double erf(double x);
TLI_DEFINE_ENUM_INTERNAL(erf)
TLI_DEFINE_STRING_INTERNAL("erf")
TLI_DEFINE_SIG_INTERNAL(Dbl, Dbl)

/// float erff(float x);
TLI_DEFINE_ENUM_INTERNAL(erff)
TLI_DEFINE_STRING_INTERNAL("erff")
TLI_DEFINE_SIG_INTERNAL(Flt, Flt)

/// long double erfl(long double x);
TLI_DEFINE_ENUM_INTERNAL(erfl)
TLI_DEFINE_STRING_INTERNAL("erfl")
TLI_DEFINE_SIG_INTERNAL(LDbl, LDbl)

/// int execl(const char *path, const char *arg, ...);
TLI_DEFINE_ENUM_INTERNAL(execl)
TLI_DEFINE_STRING_INTERNAL("execl")
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class LibCallSimplifier {
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B);
Value *optimizeSymmetric(CallInst *CI, LibFunc Func, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
IRBuilderBase &B);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Analysis/TargetLibraryInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,9 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T,
TLI.setUnavailable(LibFunc_cabs);
TLI.setUnavailable(LibFunc_cabsf);
TLI.setUnavailable(LibFunc_cabsl);
TLI.setUnavailable(LibFunc_erf);
TLI.setUnavailable(LibFunc_erff);
TLI.setUnavailable(LibFunc_erfl);
TLI.setUnavailable(LibFunc_ffs);
TLI.setUnavailable(LibFunc_flockfile);
TLI.setUnavailable(LibFunc_fseeko);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Utils/BuildLibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,9 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F,
case LibFunc_cosl:
case LibFunc_cospi:
case LibFunc_cospif:
case LibFunc_erf:
case LibFunc_erff:
case LibFunc_erfl:
case LibFunc_exp:
case LibFunc_expf:
case LibFunc_expl:
Expand Down
102 changes: 58 additions & 44 deletions llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1908,49 +1908,6 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) {
*CI, B.CreateCall(FSqrt, B.CreateFAdd(RealReal, ImagImag), "cabs"));
}

static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func,
IRBuilderBase &B) {
if (!isa<FPMathOperator>(Call))
return nullptr;

IRBuilderBase::FastMathFlagGuard Guard(B);
B.setFastMathFlags(Call->getFastMathFlags());

// TODO: Can this be shared to also handle LLVM intrinsics?
Value *X;
switch (Func) {
case LibFunc_sin:
case LibFunc_sinf:
case LibFunc_sinl:
case LibFunc_tan:
case LibFunc_tanf:
case LibFunc_tanl:
// sin(-X) --> -sin(X)
// tan(-X) --> -tan(X)
if (match(Call->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X)))))
return B.CreateFNeg(
copyFlags(*Call, B.CreateCall(Call->getCalledFunction(), X)));
break;
case LibFunc_cos:
case LibFunc_cosf:
case LibFunc_cosl: {
// cos(-x) --> cos(x)
// cos(fabs(x)) --> cos(x)
// cos(copysign(x, y)) --> cos(x)
Value *Sign;
Value *Src = Call->getArgOperand(0);
if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X))) ||
match(Src, m_CopySign(m_Value(X), m_Value(Sign))))
return copyFlags(*Call,
B.CreateCall(Call->getCalledFunction(), X, "cos"));
break;
}
default:
break;
}
return nullptr;
}

// Return a properly extended integer (DstWidth bits wide) if the operation is
// an itofp.
static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) {
Expand Down Expand Up @@ -2797,6 +2754,63 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
return true;
}

static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven,
IRBuilderBase &B) {
Value *X;
Value *Src = CI->getArgOperand(0);

if (match(Src, m_OneUse(m_FNeg(m_Value(X))))) {
IRBuilderBase::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());

auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
if (IsEven) {
// Even function: f(-x) = f(x)
return CallInst;
}
// Odd function: f(-x) = -f(x)
return B.CreateFNeg(CallInst);
}

// Even function: f(abs(x)) = f(x), f(copysign(x, y)) = f(x)
if (IsEven && (match(Src, m_FAbs(m_Value(X))) ||
match(Src, m_CopySign(m_Value(X), m_Value())))) {
IRBuilderBase::FastMathFlagGuard Guard(B);
B.setFastMathFlags(CI->getFastMathFlags());

auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
return CallInst;
}

return nullptr;
}

Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, LibFunc Func,
IRBuilderBase &B) {
switch (Func) {
case LibFunc_cos:
case LibFunc_cosf:
case LibFunc_cosl:
return optimizeSymmetricCall(CI, /*IsEven*/ true, B);

case LibFunc_sin:
case LibFunc_sinf:
case LibFunc_sinl:

case LibFunc_tan:
case LibFunc_tanf:
case LibFunc_tanl:

case LibFunc_erf:
case LibFunc_erff:
case LibFunc_erfl:
return optimizeSymmetricCall(CI, /*IsEven*/ false, B);

default:
return nullptr;
}
}

Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) {
// Make sure the prototype is as expected, otherwise the rest of the
// function is probably invalid and likely to abort.
Expand Down Expand Up @@ -3678,7 +3692,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
if (CI->isStrictFP())
return nullptr;

if (Value *V = optimizeTrigReflections(CI, Func, Builder))
if (Value *V = optimizeSymmetric(CI, Func, Builder))
return V;

switch (Func) {
Expand Down
50 changes: 50 additions & 0 deletions llvm/test/Transforms/InstCombine/math-odd-even-parity.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

declare double @erf(double)
declare double @cos(double)
declare double @fabs(double)

declare void @use(double) nounwind

; Check odd parity: -erf(-x) == erf(x)
define double @test_erf(double %x) {
; CHECK-LABEL: define double @test_erf(
; CHECK-SAME: double [[X:%.*]]) {
; CHECK-NEXT: [[RES:%.*]] = tail call reassoc double @erf(double [[X]])
; CHECK-NEXT: ret double [[RES]]
;
%neg_x = fneg double %x
%res = tail call reassoc double @erf(double %neg_x)
%neg_res = fneg double %res
ret double %neg_res
}

; Check even parity: cos(fabs(x)) == cos(x)
define double @test_cos_fabs(double %x) {
; CHECK-LABEL: define double @test_cos_fabs(
; CHECK-SAME: double [[X:%.*]]) {
; CHECK-NEXT: [[RES:%.*]] = tail call reassoc double @cos(double [[X]])
; CHECK-NEXT: ret double [[RES]]
;
%fabs_res = call double @fabs(double %x)
%res = tail call reassoc double @cos(double %fabs_res)
ret double %res
}

; Do nothing in case of multi-use
define double @test_erf_multi_use(double %x) {
; CHECK-LABEL: define double @test_erf_multi_use(
; CHECK-SAME: double [[X:%.*]]) {
; CHECK-NEXT: [[NEG_X:%.*]] = fneg double [[X]]
; CHECK-NEXT: call void @use(double [[NEG_X]])
; CHECK-NEXT: [[RES:%.*]] = call double @erf(double [[NEG_X]])
; CHECK-NEXT: [[NEG_RES:%.*]] = fneg double [[RES]]
; CHECK-NEXT: ret double [[NEG_RES]]
;
%neg_x = fneg double %x
call void @use(double %neg_x)
%res = call double @erf(double %neg_x)
%neg_res = fneg double %res
ret double %neg_res
}
4 changes: 2 additions & 2 deletions llvm/test/tools/llvm-tli-checker/ps4-tli-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@
## the exact count first; the two directives should add up to that.
## Yes, this means additions to TLI will fail this test, but the argument
## to -COUNT can't be an expression.
# AVAIL: TLI knows 476 symbols, 243 available
# AVAIL: TLI knows 479 symbols, 243 available
# AVAIL-COUNT-243: {{^}} available
# AVAIL-NOT: {{^}} available
# UNAVAIL-COUNT-233: not available
# UNAVAIL-COUNT-236: not available
# UNAVAIL-NOT: not available

## This is a large file so it's worth telling lit to stop here.
Expand Down
3 changes: 3 additions & 0 deletions llvm/unittests/Analysis/TargetLibraryInfoTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ TEST_F(TargetLibraryInfoTest, ValidProto) {
"declare double @pow(double, double)\n"
"declare float @powf(float, float)\n"
"declare x86_fp80 @powl(x86_fp80, x86_fp80)\n"
"declare double @erf(double)\n"
"declare float @erff(float)\n"
"declare x86_fp80 @erfl(x86_fp80)\n"
"declare i32 @printf(i8*, ...)\n"
"declare i32 @putc(i32, %struct*)\n"
"declare i32 @putc_unlocked(i32, %struct*)\n"
Expand Down