[mlir][math] Add Polynomial Approximation for acos, asin op #90962

Merged: 1 commit, May 7, 2024

Conversation

pashu123 (Member) commented May 3, 2024

Adds polynomial approximations for the math.acos and math.asin ops, along with integration tests.
The approximation is borrowed from
https://stackoverflow.com/a/42683455
To check accuracy, I added this script: https://gist.github.com/pashu123/cd3e682b21a64ac306f650fb842a422b, which tests 50 values between -1 and 1. The results are at https://gist.github.com/pashu123/8acb233bd045bacabfa8c992d4040465. They are well within the bounds.
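For readers who cannot open the gists, here is a minimal scalar sketch of the same kind of sweep. It is not the author's script: it is a plain double-precision transcription of the coefficient sequence in the diff, compared against std::asin. The names asinPoly and maxAbsError are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cmath>

// Scalar transcription of the FMA chain in AsinPolynomialApproximation.
// Two interleaved Horner chains in q = x^4 (r and t) are merged at the end.
double asinPoly(double x) {
  double s = x * x;
  double q = s * s;
  double r = 5.5579749017470502e-2;
  double t = -6.2027913464120114e-2;
  r = std::fma(r, q, 5.4224464349245036e-2);
  t = std::fma(t, q, -1.1326992890324464e-2);
  r = std::fma(r, q, 1.5268872539397656e-2);
  t = std::fma(t, q, 1.0493798473372081e-2);
  r = std::fma(r, q, 1.4106045900607047e-2);
  t = std::fma(t, q, 1.7339776384962050e-2);
  r = std::fma(r, q, 2.2372961589651054e-2);
  t = std::fma(t, q, 3.0381912707941005e-2);
  r = std::fma(r, q, 4.4642857881094775e-2);
  t = std::fma(t, q, 7.4999999991367292e-2);
  r = std::fma(r, s, t);                      // merge the two chains
  r = std::fma(r, s, 1.6666666666670193e-1);  // x^3 coefficient, close to 1/6
  t = x * s;                                  // x^3
  return std::fma(r, t, x);                   // x + x^3 * P(x^2)
}

// Hypothetical helper: largest |asinPoly(x) - std::asin(x)| over `samples`
// evenly spaced points in [lo, hi], endpoints included.
double maxAbsError(double lo, double hi, int samples) {
  assert(samples >= 2 && "need at least two sample points");
  double worst = 0.0;
  for (int i = 0; i < samples; ++i) {
    double x = lo + (hi - lo) * i / (samples - 1);
    worst = std::fmax(worst, std::fabs(asinPoly(x) - std::asin(x)));
  }
  return worst;
}
```

Calling maxAbsError(-1.0, 1.0, 50) mirrors the 50-point sweep described above, and maxAbsError(-0.5625, 0.5625, 50) restricts it to the interval bounded by the 0.5625 threshold that the acos pattern uses. Comparing the two is a quick way to see how the bare polynomial behaves near the ends of the domain. Note that this sketch runs in double precision, while the patterns in the PR only match f32 and f16.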

llvmbot (Member) commented May 3, 2024

@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Prashant Kumar (pashu123)

Changes

Adds polynomial approximations for the math.acos and math.asin ops, along with integration tests.
The approximation is borrowed from
https://stackoverflow.com/a/42683455


Full diff: https://github.com/llvm/llvm-project/pull/90962.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+154-6)
  • (modified) mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir (+80)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..f4fae68da63b3a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -821,6 +821,153 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Asin approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates asin(x).
+// This approximation is based on the following stackoverflow post:
+// https://stackoverflow.com/a/42683455
+namespace {
+struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AsinOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
+                                             PatternRewriter &rewriter) const {
+  Value operand = op.getOperand();
+  Type elementType = getElementTypeOrSelf(operand);
+
+  if (!(elementType.isF32() || elementType.isF16()))
+    return rewriter.notifyMatchFailure(op,
+                                       "only f32 and f16 type is supported.");
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  auto fma = [&](Value a, Value b, Value c) -> Value {
+    return builder.create<math::FmaOp>(a, b, c);
+  };
+
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<arith::MulFOp>(a, b);
+  };
+
+  Value s = mul(operand, operand);
+  Value q = mul(s, s);
+  Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
+  Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
+
+  r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
+  r = fma(r, s, t);
+  r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
+  t = mul(operand, s);
+  r = fma(r, t, operand);
+
+  rewriter.replaceOp(op, r);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// Acos approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates acos(x).
+// This approximation is based on the following stackoverflow post:
+// https://stackoverflow.com/a/42683455
+namespace {
+struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AcosOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
+                                             PatternRewriter &rewriter) const {
+  Value operand = op.getOperand();
+  Type elementType = getElementTypeOrSelf(operand);
+
+  if (!(elementType.isF32() || elementType.isF16()))
+    return rewriter.notifyMatchFailure(op,
+                                       "only f32 and f16 type is supported.");
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  auto fma = [&](Value a, Value b, Value c) -> Value {
+    return builder.create<math::FmaOp>(a, b, c);
+  };
+
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<arith::MulFOp>(a, b);
+  };
+
+  Value negOperand = builder.create<arith::NegFOp>(operand);
+  Value zero = bcast(floatCst(builder, 0.0, elementType));
+  Value half = bcast(floatCst(builder, 0.5, elementType));
+  Value negOne = bcast(floatCst(builder, -1.0, elementType));
+  Value selR =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
+  Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
+  Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
+  Value firstPred =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
+
+  Value trueVal =
+      fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
+          bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+          builder.create<math::AsinOp>(r));
+
+  Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
+  falseVal = builder.create<math::AsinOp>(falseVal);
+  falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
+
+  r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
+
+  // Check whether the operand lies in between [-1.0, 0.0).
+  Value greaterThanNegOne =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
+
+  Value lessThanZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+
+  Value betweenNegOneZero =
+      builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
+
+  trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
+                bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+                builder.create<arith::NegFOp>(r));
+
+  Value finalVal =
+      builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
+
+  rewriter.replaceOp(op, finalVal);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // Erf approximation.
 //----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
           patterns.getContext());
 
-  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
-               LogApproximation, Log2Approximation, Log1pApproximation,
-               ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-               SinAndCosApproximation<false, math::CosOp>>(
-      patterns.getContext());
+  patterns
+      .add<AtanApproximation, Atan2Approximation, TanhApproximation,
+           LogApproximation, Log2Approximation, Log1pApproximation,
+           ErfPolynomialApproximation, AsinPolynomialApproximation,
+           AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+           CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
+           SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
   if (options.enableAvx2) {
     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
         patterns.getContext());
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..370c5baa0adef3 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -461,6 +461,84 @@ func.func @cos() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Asin.
+// -------------------------------------------------------------------------- //
+func.func @asin_f32(%a : f32) {
+  %r = math.asin %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @asin_3xf32(%a : vector<3xf32>) {
+  %r = math.asin %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @asin() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @asin_f32(%zero) : (f32) -> ()
+
+  // CHECK: -0.597406
+  %cst1 = arith.constant -0.5625 : f32
+  call @asin_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -0.384397
+  %cst2 = arith.constant -0.375 : f32
+  call @asin_f32(%cst2) : (f32) -> ()
+
+  // CHECK: -0.25268
+  %cst3 = arith.constant -0.25 : f32
+  call @asin_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.25268, 0.384397, 0.597406
+  %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+  call @asin_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Acos.
+// -------------------------------------------------------------------------- //
+func.func @acos_f32(%a : f32) {
+  %r = math.acos %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @acos_3xf32(%a : vector<3xf32>) {
+  %r = math.acos %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @acos() {
+  // CHECK: 1.5708
+  %zero = arith.constant 0.0 : f32
+  call @acos_f32(%zero) : (f32) -> ()
+
+  // CHECK: 2.1682
+  %cst1 = arith.constant -0.5625 : f32
+  call @acos_f32(%cst1) : (f32) -> ()
+
+  // CHECK: 1.95519
+  %cst2 = arith.constant -0.375 : f32
+  call @acos_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 1.82348
+  %cst3 = arith.constant -0.25 : f32
+  call @acos_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 1.31812, 1.1864, 0.97339
+  %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+  call @acos_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
 // -------------------------------------------------------------------------- //
 // Atan.
 // -------------------------------------------------------------------------- //
@@ -694,6 +772,8 @@ func.func @main() {
   call @expm1(): () -> ()
   call @sin(): () -> ()
   call @cos(): () -> ()
+  call @asin(): () -> ()
+  call @acos(): () -> ()
   call @atan() : () -> ()
   call @atan2() : () -> ()
   call @cbrt() : () -> ()

pashu123 requested a review from rsuderman May 3, 2024 12:31
Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));

r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
Member (Author):

The naming has been borrowed from the post https://stackoverflow.com/a/42683455 .

Contributor:

Yeah, I was going to suggest using the Remez algorithm instead of a Taylor series. However, aren't these parameters optimized for [-9/16, 9/16]? What happens for other input values?

Member (Author):

I added this script: https://gist.github.com/pashu123/cd3e682b21a64ac306f650fb842a422b to test 50 values between -1 and 1. The results are at https://gist.github.com/pashu123/8acb233bd045bacabfa8c992d4040465. They are well within the bounds.

Contributor:

Thanks, can you add it to the PR description?
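On the range question in this thread, it may help to see the acos rewrite as scalar code. The sketch below is an illustrative transcription of the select logic in AcosPolynomialApproximation, not code from the PR. The names asinCore and acosApprox are hypothetical, and the polynomial is called directly here, whereas the pattern emits a math.asin op. With this reduction, the polynomial only ever receives arguments of magnitude at most 9/16.

```cpp
#include <cmath>

// Same coefficient chain as the asin pattern in the diff.
double asinCore(double x) {
  double s = x * x;
  double q = s * s;
  double r = 5.5579749017470502e-2;
  double t = -6.2027913464120114e-2;
  r = std::fma(r, q, 5.4224464349245036e-2);
  t = std::fma(t, q, -1.1326992890324464e-2);
  r = std::fma(r, q, 1.5268872539397656e-2);
  t = std::fma(t, q, 1.0493798473372081e-2);
  r = std::fma(r, q, 1.4106045900607047e-2);
  t = std::fma(t, q, 1.7339776384962050e-2);
  r = std::fma(r, q, 2.2372961589651054e-2);
  t = std::fma(t, q, 3.0381912707941005e-2);
  r = std::fma(r, q, 4.4642857881094775e-2);
  t = std::fma(t, q, 7.4999999991367292e-2);
  r = std::fma(r, s, t);
  r = std::fma(r, s, 1.6666666666670193e-1);
  return std::fma(r, x * s, x);
}

// Scalar version of the acos rewrite: reduce to acos(|x|), then reflect.
double acosApprox(double x) {
  double r = (x > 0.0) ? -x : x;  // r = -|x|
  if (r > -0.5625) {
    // |x| < 9/16: acos(|x|) = pi/2 - asin(|x|) = pi/2 + asin(r).
    // The two constants multiply to pi/2 inside the fma.
    r = std::fma(9.3282184640716537e-1, 1.6839188885261840e+0, asinCore(r));
  } else {
    // |x| >= 9/16: acos(|x|) = 2 * asin(sqrt((1 - |x|) / 2)).
    // The sqrt argument is at most 0.21875, so asinCore sees at most ~0.468.
    r = 2.0 * asinCore(std::sqrt(std::fma(0.5, r, 0.5)));
  }
  if (x >= -1.0 && x < 0.0) {
    // Negative inputs: acos(x) = pi - acos(-x); the constants multiply to pi.
    r = std::fma(1.8656436928143307e+0, 1.6839188885261840e+0, -r);
  }
  return r;
}
```

For example, acosApprox(-0.5625) takes the sqrt branch and then the reflection, which corresponds to the 2.1682 CHECK line in the integration test. The standalone asin pattern has no such reduction in this diff, which is what the question about inputs outside [-9/16, 9/16] is getting at.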

pashu123 requested a review from hanhanW May 4, 2024 02:46
pashu123 merged commit 7208569 into llvm:main May 7, 2024