Skip to content

[mlir][math] Fix polynomial math.asin approximation #101247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2024

Conversation

rsuderman
Copy link
Contributor

The polynomial approximation for asin is only good between [-9/16, 9/16]. Values beyond that range must be remapped to achieve good numeric results. This is done by the equation below:

arcsin(x) = PI/2 - arcsin(sqrt(1.0 - x*x))

The polynomial approximation for asin is only good between [-9/16,
9/16]. Values beyond that range must be remapped to achieve good numeric
results. This is done by the equation below:

`arcsin(x) = PI/2 - arcsin(sqrt(1.0 - x*x))`
@llvmbot
Copy link
Member

llvmbot commented Jul 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Rob Suderman (rsuderman)

Changes

The polynomial approximation for asin is only good between [-9/16, 9/16]. Values beyond that range must be remapped to achieve good numeric results. This is done by the equation below:

arcsin(x) = PI/2 - arcsin(sqrt(1.0 - x*x))


Full diff: https://github.com/llvm/llvm-project/pull/101247.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+34-3)
  • (modified) mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir (+4)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f4fae68da63b3..d6fe22158d002 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -861,7 +861,34 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
     return builder.create<arith::MulFOp>(a, b);
   };
 
-  Value s = mul(operand, operand);
+  auto sub = [&](Value a, Value b) -> Value {
+    return builder.create<arith::SubFOp>(a, b);
+  };
+
+  auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); };
+
+  auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); };
+
+  auto scopy = [&](Value a, Value b) -> Value {
+    return builder.create<math::CopySignOp>(a, b);
+  };
+
+  auto sel = [&](Value a, Value b, Value c) -> Value {
+    return builder.create<arith::SelectOp>(a, b, c);
+  };
+
+  Value abso = abs(operand);
+  Value aa = mul(operand, operand);
+  Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));
+
+  Value gt =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
+                                    bcast(floatCst(builder, 0.5, elementType)));
+
+  Value x = sel(gt, opp, abso);
+
+  // Asin(x) approximation for x = [-9/16, 9/16]:
+  Value s = mul(x, x);
   Value q = mul(s, s);
   Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
   Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
@@ -878,8 +905,12 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
   t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
   r = fma(r, s, t);
   r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
-  t = mul(operand, s);
-  r = fma(r, t, operand);
+  t = mul(x, s);
+  r = fma(r, t, x);
+
+  Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r);
+  r = sel(gt, rsub, r);
+  r = scopy(r, operand);
 
   rewriter.replaceOp(op, r);
   return success();
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 370c5baa0adef..b8861198d596b 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -493,6 +493,10 @@ func.func @asin() {
   %cst3 = arith.constant -0.25 : f32
   call @asin_f32(%cst3) : (f32) -> ()
 
+  // CHECK: -1.1197
+  %cst4 = arith.constant -0.90 : f32
+  call @asin_f32(%cst4) : (f32) -> ()
+
   // CHECK: 0.25268, 0.384397, 0.597406
   %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
   call @asin_3xf32(%vec_x) : (vector<3xf32>) -> ()

@rsuderman rsuderman requested a review from pashu123 July 30, 2024 22:31
Copy link
Member

@pashu123 pashu123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@rsuderman rsuderman merged commit a3fb301 into llvm:main Jul 31, 2024
10 checks passed
@rsuderman rsuderman deleted the asin_numerics_rebase branch July 31, 2024 17:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants