Skip to content

[InstCombine] Optimise x / sqrt(y / z) with fast-math pattern. #76737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 9, 2024

Conversation

zjaffal
Copy link
Contributor

@zjaffal zjaffal commented Jan 2, 2024

Replace the pattern with
x * sqrt(z/y)

@llvmbot
Copy link
Member

llvmbot commented Jan 2, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Zain Jaffal (zjaffal)

Changes

Replace the pattern with
x * sqrt(z/y)


Full diff: https://github.com/llvm/llvm-project/pull/76737.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+28)
  • (added) llvm/test/Transforms/InstCombine/fdiv-sqrt.ll (+85)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f0ea3d9fcad5df..172ce18d003aa8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1701,6 +1701,31 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
   return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
 }
 
+/// Convert div to mul if we have an sqrt divisor iff sqrt's operand is a fdiv
+/// instruction.
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
+                                        InstCombiner::BuilderTy &Builder) {
+  // X / sqrt(Y / Z) -->  X * sqrt(Z / Y)
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  auto *II = dyn_cast<IntrinsicInst>(Op1);
+  if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
+      !I.hasAllowReassoc() || !I.hasAllowReciprocal())
+    return nullptr;
+
+  Value *Y, *Z;
+  auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
+  if (!DivOp || !DivOp->hasOneUse() || !DivOp->hasAllowReassoc() ||
+      !I.hasAllowReciprocal())
+    return nullptr;
+  if (match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) {
+    Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+    Value *NewSqrt = Builder.CreateIntrinsic(II->getIntrinsicID(),
+                                             II->getType(), {SwapDiv}, II);
+    return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   Module *M = I.getModule();
 
@@ -1808,6 +1833,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
     return Mul;
 
+  if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
+    return Mul;
+
   // pow(X, Y) / X --> pow(X, Y-1)
   if (I.hasAllowReassoc() &&
       match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
new file mode 100644
index 00000000000000..3f41b0f24ae040
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -0,0 +1,85 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare double @llvm.sqrt.f64(double)
+
+define double @sqrt_div_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_fast(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret double [[DIV1]]
+;
+entry:
+  %div = fdiv fast double %y, %z
+  %sqrt = call fast double @llvm.sqrt.f64(double %div)
+  %div1 = fdiv fast double %x, %sqrt
+  ret double %div1
+}
+
+define double @sqrt_div(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fdiv double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    ret double [[DIV1]]
+;
+entry:
+  %div = fdiv double %y, %z
+  %sqrt = call double @llvm.sqrt.f64(double %div)
+  %div1 = fdiv double %x, %sqrt
+  ret double %div1
+}
+
+define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_reassoc_arcp(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret double [[DIV1]]
+;
+entry:
+  %div = fdiv reassoc arcp double %y, %z
+  %sqrt = call reassoc arcp double @llvm.sqrt.f64(double %div)
+  %div1 = fdiv reassoc arcp double %x, %sqrt
+  ret double %div1
+}
+
+declare void @use(double)
+define double @sqrt_div_fast_multiple_uses_1(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_fast_multiple_uses_1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    call void @use(double [[DIV]])
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    ret double [[DIV1]]
+;
+entry:
+  %div = fdiv fast double %y, %z
+  call void @use(double %div)
+  %sqrt = call fast double @llvm.sqrt.f64(double %div)
+  %div1 = fdiv fast double %x, %sqrt
+  ret double %div1
+}
+
+define double @sqrt_div_fast_multiple_uses_2(double %x, double %y, double %z) {
+; CHECK-LABEL: @sqrt_div_fast_multiple_uses_2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
+; CHECK-NEXT:    call void @use(double [[SQRT]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    ret double [[DIV1]]
+;
+entry:
+  %div = fdiv fast double %y, %z
+  %sqrt = call fast double @llvm.sqrt.f64(double %div)
+  call void @use(double %sqrt)
+  %div1 = fdiv fast double %x, %sqrt
+  ret double %div1
+}
+

@zjaffal zjaffal requested review from arsenm and dtcxzyw January 2, 2024 17:16
@zjaffal zjaffal self-assigned this Jan 2, 2024
@zjaffal zjaffal linked an issue Jan 2, 2024 that may be closed by this pull request
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
auto *II = dyn_cast<IntrinsicInst>(Op1);
if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
!I.hasAllowReassoc() || !I.hasAllowReciprocal())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need reassoc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my assumption was following the implementation of foldFDivPowDivisor or at least the div inside the sqrt should have that flag enabled?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using x.recip() as shorthand for (1.0 / x), arcp allows us to convert between a / b and a * b.recip() freely, and arguably (a / b).recip() to (b / a) (at the very least, gcc will do this with just -freciprocal-math).

We can convert x / sqrt(y / z) to x * sqrt(y / z).recip() with just arcp, and we can also convert x * sqrt((y / z).recip()) to x * sqrt(z / y) with just arcp as well, but the question is which flags are necessary to convert sqrt(a).recip() to sqrt(a.recip()). I think it stretches arcp too far to allow this kind of permutation of recip. reassoc isn't entirely the right flag either; what we probably want is a more generic "allow algebraic identity" flag, but we're already using reassoc for that purpose elsewhere, so we might as well use it here as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay so from my understanding then the pattern we are looking for is the following

%div = fdiv arcp double %y, %z
%sqrt = call reassoc double @llvm.sqrt.f64(double %div)
%div2 = fdiv arcp double %x, %sqrt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite, you'd need the reassoc on the first fdiv as well, and at that point, it's probably best to just require that all the operations have reassoc and arcp.

@nikic nikic removed their request for review January 2, 2024 17:51
@dtcxzyw dtcxzyw requested a review from jcranmer-intel January 2, 2024 18:43
// X / sqrt(Y / Z) --> X * sqrt(Z / Y)
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
auto *II = dyn_cast<IntrinsicInst>(Op1);
if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be cleaner to use PatternMatch. Something like

if (match(I, m_Sqrt(m_OneUse(m_FDiv(m_Value(Op0), m_Value(Op1))) {

}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree but then we won't be able to check if sqrt has one use or fdiv has the necessary flags

%div = fdiv fast double %y, %z
call void @use(double %div)
%sqrt = call fast double @llvm.sqrt.f64(double %div)
%div1 = fdiv fast double %x, %sqrt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use reduced set of flags, and have some cases where flags are missing from individual instructions

Copy link
Contributor Author

@zjaffal zjaffal Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use reduced set of flags

Is it satisfactory to assume only reassoc arcp are enabled?

have some cases where flags are missing from individual instructions

Will do thanks for the suggestion

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so

@zjaffal
Copy link
Contributor Author

zjaffal commented Jan 10, 2024

ping

1 similar comment
@zjaffal
Copy link
Contributor Author

zjaffal commented Jan 23, 2024

ping

; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -passes=instcombine < %s | FileCheck %s

declare double @llvm.sqrt.f64(double)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-submit baseline tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, or directly push

Copy link

github-actions bot commented Feb 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@zjaffal
Copy link
Contributor Author

zjaffal commented Feb 8, 2024

@arsenm thanks for your help i will push the test commits directly to main and then rebase this branch and merge

@zjaffal zjaffal force-pushed the zjaffal/issue65608 branch from 9171f25 to ab33697 Compare February 8, 2024 13:42
@zjaffal zjaffal force-pushed the zjaffal/issue65608 branch from ab33697 to fca21cf Compare February 8, 2024 19:27
@zjaffal
Copy link
Contributor Author

zjaffal commented Feb 9, 2024

@arsenm i forced push the branch after cherry picking the tests to main now every check passes
shall I merge it

@zjaffal zjaffal merged commit bb5c389 into llvm:main Feb 9, 2024
@zjaffal zjaffal deleted the zjaffal/issue65608 branch February 9, 2024 17:24
mstorsjo added a commit that referenced this pull request Feb 10, 2024
#76737)"

This reverts commit bb5c389.

That commit caused failed asserts like this:

$ cat repro.c
float a, b;
double sqrt();
void c() { b = a / sqrt(a); }
$ clang -target x86_64-linux-gnu -c -O2 -ffast-math repro.c
clang: ../lib/IR/Instruction.cpp:522: bool llvm::Instruction::hasAllowReassoc() const: Assertion `isa<FPMathOperator>(this) && "getting fast-math flag on invalid op"' failed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

clang is suboptimal for x / sqrt(y / z) with fast-math
5 participants