Skip to content

Commit 7208569

Browse files
authored
[mlir][math] Add Polynomial Approximation for acos, asin op (#90962)
Adds polynomial approximations for the math.acos and math.asin ops, together with integration tests. The approximation is borrowed from https://stackoverflow.com/a/42683455. I used this script, https://gist.github.com/pashu123/cd3e682b21a64ac306f650fb842a422b, to test 50 values between -1 and 1. The results are at https://gist.github.com/pashu123/8acb233bd045bacabfa8c992d4040465 and are well within the bounds.
1 parent 486695d commit 7208569

File tree

2 files changed

+234
-6
lines changed

2 files changed

+234
-6
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 154 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,153 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
821821
return success();
822822
}
823823

824+
//----------------------------------------------------------------------------//
825+
// Asin approximation.
826+
//----------------------------------------------------------------------------//
827+
828+
// Approximates asin(x).
829+
// This approximation is based on the following stackoverflow post:
830+
// https://stackoverflow.com/a/42683455
831+
namespace {
832+
struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
833+
public:
834+
using OpRewritePattern::OpRewritePattern;
835+
836+
LogicalResult matchAndRewrite(math::AsinOp op,
837+
PatternRewriter &rewriter) const final;
838+
};
839+
} // namespace
840+
LogicalResult
841+
AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
842+
PatternRewriter &rewriter) const {
843+
Value operand = op.getOperand();
844+
Type elementType = getElementTypeOrSelf(operand);
845+
846+
if (!(elementType.isF32() || elementType.isF16()))
847+
return rewriter.notifyMatchFailure(op,
848+
"only f32 and f16 type is supported.");
849+
VectorShape shape = vectorShape(operand);
850+
851+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
852+
auto bcast = [&](Value value) -> Value {
853+
return broadcast(builder, value, shape);
854+
};
855+
856+
auto fma = [&](Value a, Value b, Value c) -> Value {
857+
return builder.create<math::FmaOp>(a, b, c);
858+
};
859+
860+
auto mul = [&](Value a, Value b) -> Value {
861+
return builder.create<arith::MulFOp>(a, b);
862+
};
863+
864+
Value s = mul(operand, operand);
865+
Value q = mul(s, s);
866+
Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
867+
Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
868+
869+
r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
870+
t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
871+
r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
872+
t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
873+
r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
874+
t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
875+
r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
876+
t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
877+
r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
878+
t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
879+
r = fma(r, s, t);
880+
r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
881+
t = mul(operand, s);
882+
r = fma(r, t, operand);
883+
884+
rewriter.replaceOp(op, r);
885+
return success();
886+
}
887+
888+
//----------------------------------------------------------------------------//
889+
// Acos approximation.
890+
//----------------------------------------------------------------------------//
891+
892+
// Approximates acos(x).
893+
// This approximation is based on the following stackoverflow post:
894+
// https://stackoverflow.com/a/42683455
895+
namespace {
896+
struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
897+
public:
898+
using OpRewritePattern::OpRewritePattern;
899+
900+
LogicalResult matchAndRewrite(math::AcosOp op,
901+
PatternRewriter &rewriter) const final;
902+
};
903+
} // namespace
904+
LogicalResult
905+
AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
906+
PatternRewriter &rewriter) const {
907+
Value operand = op.getOperand();
908+
Type elementType = getElementTypeOrSelf(operand);
909+
910+
if (!(elementType.isF32() || elementType.isF16()))
911+
return rewriter.notifyMatchFailure(op,
912+
"only f32 and f16 type is supported.");
913+
VectorShape shape = vectorShape(operand);
914+
915+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
916+
auto bcast = [&](Value value) -> Value {
917+
return broadcast(builder, value, shape);
918+
};
919+
920+
auto fma = [&](Value a, Value b, Value c) -> Value {
921+
return builder.create<math::FmaOp>(a, b, c);
922+
};
923+
924+
auto mul = [&](Value a, Value b) -> Value {
925+
return builder.create<arith::MulFOp>(a, b);
926+
};
927+
928+
Value negOperand = builder.create<arith::NegFOp>(operand);
929+
Value zero = bcast(floatCst(builder, 0.0, elementType));
930+
Value half = bcast(floatCst(builder, 0.5, elementType));
931+
Value negOne = bcast(floatCst(builder, -1.0, elementType));
932+
Value selR =
933+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
934+
Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
935+
Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
936+
Value firstPred =
937+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
938+
939+
Value trueVal =
940+
fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
941+
bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
942+
builder.create<math::AsinOp>(r));
943+
944+
Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
945+
falseVal = builder.create<math::AsinOp>(falseVal);
946+
falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
947+
948+
r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
949+
950+
// Check whether the operand lies in between [-1.0, 0.0).
951+
Value greaterThanNegOne =
952+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
953+
954+
Value lessThanZero =
955+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
956+
957+
Value betweenNegOneZero =
958+
builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
959+
960+
trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
961+
bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
962+
builder.create<arith::NegFOp>(r));
963+
964+
Value finalVal =
965+
builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
966+
967+
rewriter.replaceOp(op, finalVal);
968+
return success();
969+
}
970+
824971
//----------------------------------------------------------------------------//
825972
// Erf approximation.
826973
//----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
15051652
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
15061653
patterns.getContext());
15071654

1508-
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1509-
LogApproximation, Log2Approximation, Log1pApproximation,
1510-
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1511-
CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1512-
SinAndCosApproximation<false, math::CosOp>>(
1513-
patterns.getContext());
1655+
patterns
1656+
.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1657+
LogApproximation, Log2Approximation, Log1pApproximation,
1658+
ErfPolynomialApproximation, AsinPolynomialApproximation,
1659+
AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1660+
CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1661+
SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
15141662
if (options.enableAvx2) {
15151663
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
15161664
patterns.getContext());

mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,84 @@ func.func @cos() {
461461
return
462462
}
463463

464+
// -------------------------------------------------------------------------- //
465+
// Asin.
466+
// -------------------------------------------------------------------------- //
467+
// Applies math.asin to one f32 value and prints the result.
func.func @asin_f32(%input : f32) {
  %result = math.asin %input : f32
  vector.print %result : f32
  return
}
472+
473+
// Applies math.asin element-wise to a 3-element f32 vector and prints it.
func.func @asin_3xf32(%input : vector<3xf32>) {
  %result = math.asin %input : vector<3xf32>
  vector.print %result : vector<3xf32>
  return
}
478+
479+
// Drives the asin helpers with zero, three negative scalars and one positive
// vector; each call prints one line that is matched by the directive above it.
func.func @asin() {
  // CHECK: 0
  %x_zero = arith.constant 0.0 : f32
  call @asin_f32(%x_zero) : (f32) -> ()

  // CHECK: -0.597406
  %x_neg_9_16 = arith.constant -0.5625 : f32
  call @asin_f32(%x_neg_9_16) : (f32) -> ()

  // CHECK: -0.384397
  %x_neg_3_8 = arith.constant -0.375 : f32
  call @asin_f32(%x_neg_3_8) : (f32) -> ()

  // CHECK: -0.25268
  %x_neg_1_4 = arith.constant -0.25 : f32
  call @asin_f32(%x_neg_1_4) : (f32) -> ()

  // CHECK: 0.25268, 0.384397, 0.597406
  %x_vec = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
  call @asin_3xf32(%x_vec) : (vector<3xf32>) -> ()

  return
}
502+
503+
// -------------------------------------------------------------------------- //
504+
// Acos.
505+
// -------------------------------------------------------------------------- //
506+
// Applies math.acos to one f32 value and prints the result.
func.func @acos_f32(%input : f32) {
  %result = math.acos %input : f32
  vector.print %result : f32
  return
}
511+
512+
// Applies math.acos element-wise to a 3-element f32 vector and prints it.
func.func @acos_3xf32(%input : vector<3xf32>) {
  %result = math.acos %input : vector<3xf32>
  vector.print %result : vector<3xf32>
  return
}
517+
518+
// Drives the acos helpers with zero, three negative scalars and one positive
// vector; each call prints one line that is matched by the directive above it.
func.func @acos() {
  // CHECK: 1.5708
  %x_zero = arith.constant 0.0 : f32
  call @acos_f32(%x_zero) : (f32) -> ()

  // CHECK: 2.1682
  %x_neg_9_16 = arith.constant -0.5625 : f32
  call @acos_f32(%x_neg_9_16) : (f32) -> ()

  // CHECK: 1.95519
  %x_neg_3_8 = arith.constant -0.375 : f32
  call @acos_f32(%x_neg_3_8) : (f32) -> ()

  // CHECK: 1.82348
  %x_neg_1_4 = arith.constant -0.25 : f32
  call @acos_f32(%x_neg_1_4) : (f32) -> ()

  // CHECK: 1.31812, 1.1864, 0.97339
  %x_vec = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
  call @acos_3xf32(%x_vec) : (vector<3xf32>) -> ()

  return
}
541+
464542
// -------------------------------------------------------------------------- //
465543
// Atan.
466544
// -------------------------------------------------------------------------- //
@@ -694,6 +772,8 @@ func.func @main() {
694772
call @expm1(): () -> ()
695773
call @sin(): () -> ()
696774
call @cos(): () -> ()
775+
call @asin(): () -> ()
776+
call @acos(): () -> ()
697777
call @atan() : () -> ()
698778
call @atan2() : () -> ()
699779
call @cbrt() : () -> ()

0 commit comments

Comments
 (0)