-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math] Add Polynomial Approximation for acos, asin op #90962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Adds the Polynomial Approximation for math.acos and math.asin op. Also, adds integration tests. The Approximation has been borrowed from https://stackoverflow.com/a/42683455
@llvm/pr-subscribers-mlir-math @llvm/pr-subscribers-mlir Author: Prashant Kumar (pashu123) ChangesAdds the Polynomial Approximation for math.acos and math.asin op. Also, adds integration tests. Full diff: https://github.com/llvm/llvm-project/pull/90962.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..f4fae68da63b3a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -821,6 +821,153 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
return success();
}
+//----------------------------------------------------------------------------//
+// Asin approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates asin(x).
+// This approximation is based on the following stackoverflow post:
+// https://stackoverflow.com/a/42683455
+namespace {
+struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::AsinOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
+ PatternRewriter &rewriter) const {
+ Value operand = op.getOperand();
+ Type elementType = getElementTypeOrSelf(operand);
+
+ if (!(elementType.isF32() || elementType.isF16()))
+ return rewriter.notifyMatchFailure(op,
+ "only f32 and f16 type is supported.");
+ VectorShape shape = vectorShape(operand);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ auto fma = [&](Value a, Value b, Value c) -> Value {
+ return builder.create<math::FmaOp>(a, b, c);
+ };
+
+ auto mul = [&](Value a, Value b) -> Value {
+ return builder.create<arith::MulFOp>(a, b);
+ };
+
+ Value s = mul(operand, operand);
+ Value q = mul(s, s);
+ Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
+ Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
+
+ r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
+ r = fma(r, s, t);
+ r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
+ t = mul(operand, s);
+ r = fma(r, t, operand);
+
+ rewriter.replaceOp(op, r);
+ return success();
+}
+
+//----------------------------------------------------------------------------//
+// Acos approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates acos(x).
+// This approximation is based on the following stackoverflow post:
+// https://stackoverflow.com/a/42683455
+namespace {
+struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::AcosOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
+ PatternRewriter &rewriter) const {
+ Value operand = op.getOperand();
+ Type elementType = getElementTypeOrSelf(operand);
+
+ if (!(elementType.isF32() || elementType.isF16()))
+ return rewriter.notifyMatchFailure(op,
+ "only f32 and f16 type is supported.");
+ VectorShape shape = vectorShape(operand);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ auto fma = [&](Value a, Value b, Value c) -> Value {
+ return builder.create<math::FmaOp>(a, b, c);
+ };
+
+ auto mul = [&](Value a, Value b) -> Value {
+ return builder.create<arith::MulFOp>(a, b);
+ };
+
+ Value negOperand = builder.create<arith::NegFOp>(operand);
+ Value zero = bcast(floatCst(builder, 0.0, elementType));
+ Value half = bcast(floatCst(builder, 0.5, elementType));
+ Value negOne = bcast(floatCst(builder, -1.0, elementType));
+ Value selR =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
+ Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
+ Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
+ Value firstPred =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
+
+ Value trueVal =
+ fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
+ bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+ builder.create<math::AsinOp>(r));
+
+ Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
+ falseVal = builder.create<math::AsinOp>(falseVal);
+ falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
+
+ r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
+
+ // Check whether the operand lies in between [-1.0, 0.0).
+ Value greaterThanNegOne =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
+
+ Value lessThanZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+
+ Value betweenNegOneZero =
+ builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
+
+ trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
+ bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+ builder.create<arith::NegFOp>(r));
+
+ Value finalVal =
+ builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
+
+ rewriter.replaceOp(op, finalVal);
+ return success();
+}
+
//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
patterns.getContext());
- patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(
- patterns.getContext());
+ patterns
+ .add<AtanApproximation, Atan2Approximation, TanhApproximation,
+ LogApproximation, Log2Approximation, Log1pApproximation,
+ ErfPolynomialApproximation, AsinPolynomialApproximation,
+ AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+ CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
+ SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
if (options.enableAvx2) {
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
patterns.getContext());
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..370c5baa0adef3 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -461,6 +461,84 @@ func.func @cos() {
return
}
+// -------------------------------------------------------------------------- //
+// Asin.
+// -------------------------------------------------------------------------- //
+func.func @asin_f32(%a : f32) {
+ %r = math.asin %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @asin_3xf32(%a : vector<3xf32>) {
+ %r = math.asin %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @asin() {
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @asin_f32(%zero) : (f32) -> ()
+
+ // CHECK: -0.597406
+ %cst1 = arith.constant -0.5625 : f32
+ call @asin_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -0.384397
+ %cst2 = arith.constant -0.375 : f32
+ call @asin_f32(%cst2) : (f32) -> ()
+
+ // CHECK: -0.25268
+ %cst3 = arith.constant -0.25 : f32
+ call @asin_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 0.25268, 0.384397, 0.597406
+ %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+ call @asin_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+// -------------------------------------------------------------------------- //
+// Acos.
+// -------------------------------------------------------------------------- //
+func.func @acos_f32(%a : f32) {
+ %r = math.acos %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @acos_3xf32(%a : vector<3xf32>) {
+ %r = math.acos %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @acos() {
+ // CHECK: 1.5708
+ %zero = arith.constant 0.0 : f32
+ call @acos_f32(%zero) : (f32) -> ()
+
+ // CHECK: 2.1682
+ %cst1 = arith.constant -0.5625 : f32
+ call @acos_f32(%cst1) : (f32) -> ()
+
+ // CHECK: 1.95519
+ %cst2 = arith.constant -0.375 : f32
+ call @acos_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 1.82348
+ %cst3 = arith.constant -0.25 : f32
+ call @acos_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 1.31812, 1.1864, 0.97339
+ %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+ call @acos_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
// -------------------------------------------------------------------------- //
// Atan.
// -------------------------------------------------------------------------- //
@@ -694,6 +772,8 @@ func.func @main() {
call @expm1(): () -> ()
call @sin(): () -> ()
call @cos(): () -> ()
+ call @asin(): () -> ()
+ call @acos(): () -> ()
call @atan() : () -> ()
call @atan2() : () -> ()
call @cbrt() : () -> ()
|
Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType)); | ||
Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType)); | ||
|
||
r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The naming has been borrowed from the post https://stackoverflow.com/a/42683455 .
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah.. I was going to suggest using Remez algorithm instead of Taylor series. However, aren't these parameters optimized for [-9/16, 9/16]
? What happens to other input values...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this script: https://gist.github.com/pashu123/cd3e682b21a64ac306f650fb842a422b to test 50 values between -1 and 1. The results are https://gist.github.com/pashu123/8acb233bd045bacabfa8c992d4040465 It's well within the bounds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, can you add it to the PR description?
Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType)); | ||
Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType)); | ||
|
||
r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah.. I was going to suggest using Remez algorithm instead of Taylor series. However, aren't these parameters optimized for [-9/16, 9/16]
? What happens to other input values...?
Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType)); | ||
Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType)); | ||
|
||
r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, can you add it to the PR description?
Adds the Polynomial Approximation for math.acos and math.asin op. Also, it adds integration tests.
The Approximation has been borrowed from
https://stackoverflow.com/a/42683455
I added this script: https://gist.github.com/pashu123/cd3e682b21a64ac306f650fb842a422b to test 50 values between -1 and 1. The results are https://gist.github.com/pashu123/8acb233bd045bacabfa8c992d4040465. It's well within the bounds.