@@ -821,6 +821,153 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
821
821
return success ();
822
822
}
823
823
824
+ // ----------------------------------------------------------------------------//
825
+ // Asin approximation.
826
+ // ----------------------------------------------------------------------------//
827
+
828
+ // Approximates asin(x).
829
+ // This approximation is based on the following stackoverflow post:
830
+ // https://stackoverflow.com/a/42683455
831
+ namespace {
832
+ struct AsinPolynomialApproximation : public OpRewritePattern <math::AsinOp> {
833
+ public:
834
+ using OpRewritePattern::OpRewritePattern;
835
+
836
+ LogicalResult matchAndRewrite (math::AsinOp op,
837
+ PatternRewriter &rewriter) const final ;
838
+ };
839
+ } // namespace
840
+ LogicalResult
841
+ AsinPolynomialApproximation::matchAndRewrite (math::AsinOp op,
842
+ PatternRewriter &rewriter) const {
843
+ Value operand = op.getOperand ();
844
+ Type elementType = getElementTypeOrSelf (operand);
845
+
846
+ if (!(elementType.isF32 () || elementType.isF16 ()))
847
+ return rewriter.notifyMatchFailure (op,
848
+ " only f32 and f16 type is supported." );
849
+ VectorShape shape = vectorShape (operand);
850
+
851
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
852
+ auto bcast = [&](Value value) -> Value {
853
+ return broadcast (builder, value, shape);
854
+ };
855
+
856
+ auto fma = [&](Value a, Value b, Value c) -> Value {
857
+ return builder.create <math::FmaOp>(a, b, c);
858
+ };
859
+
860
+ auto mul = [&](Value a, Value b) -> Value {
861
+ return builder.create <arith::MulFOp>(a, b);
862
+ };
863
+
864
+ Value s = mul (operand, operand);
865
+ Value q = mul (s, s);
866
+ Value r = bcast (floatCst (builder, 5.5579749017470502e-2 , elementType));
867
+ Value t = bcast (floatCst (builder, -6.2027913464120114e-2 , elementType));
868
+
869
+ r = fma (r, q, bcast (floatCst (builder, 5.4224464349245036e-2 , elementType)));
870
+ t = fma (t, q, bcast (floatCst (builder, -1.1326992890324464e-2 , elementType)));
871
+ r = fma (r, q, bcast (floatCst (builder, 1.5268872539397656e-2 , elementType)));
872
+ t = fma (t, q, bcast (floatCst (builder, 1.0493798473372081e-2 , elementType)));
873
+ r = fma (r, q, bcast (floatCst (builder, 1.4106045900607047e-2 , elementType)));
874
+ t = fma (t, q, bcast (floatCst (builder, 1.7339776384962050e-2 , elementType)));
875
+ r = fma (r, q, bcast (floatCst (builder, 2.2372961589651054e-2 , elementType)));
876
+ t = fma (t, q, bcast (floatCst (builder, 3.0381912707941005e-2 , elementType)));
877
+ r = fma (r, q, bcast (floatCst (builder, 4.4642857881094775e-2 , elementType)));
878
+ t = fma (t, q, bcast (floatCst (builder, 7.4999999991367292e-2 , elementType)));
879
+ r = fma (r, s, t);
880
+ r = fma (r, s, bcast (floatCst (builder, 1.6666666666670193e-1 , elementType)));
881
+ t = mul (operand, s);
882
+ r = fma (r, t, operand);
883
+
884
+ rewriter.replaceOp (op, r);
885
+ return success ();
886
+ }
887
+
888
+ // ----------------------------------------------------------------------------//
889
+ // Acos approximation.
890
+ // ----------------------------------------------------------------------------//
891
+
892
+ // Approximates acos(x).
893
+ // This approximation is based on the following stackoverflow post:
894
+ // https://stackoverflow.com/a/42683455
895
+ namespace {
896
+ struct AcosPolynomialApproximation : public OpRewritePattern <math::AcosOp> {
897
+ public:
898
+ using OpRewritePattern::OpRewritePattern;
899
+
900
+ LogicalResult matchAndRewrite (math::AcosOp op,
901
+ PatternRewriter &rewriter) const final ;
902
+ };
903
+ } // namespace
904
+ LogicalResult
905
+ AcosPolynomialApproximation::matchAndRewrite (math::AcosOp op,
906
+ PatternRewriter &rewriter) const {
907
+ Value operand = op.getOperand ();
908
+ Type elementType = getElementTypeOrSelf (operand);
909
+
910
+ if (!(elementType.isF32 () || elementType.isF16 ()))
911
+ return rewriter.notifyMatchFailure (op,
912
+ " only f32 and f16 type is supported." );
913
+ VectorShape shape = vectorShape (operand);
914
+
915
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
916
+ auto bcast = [&](Value value) -> Value {
917
+ return broadcast (builder, value, shape);
918
+ };
919
+
920
+ auto fma = [&](Value a, Value b, Value c) -> Value {
921
+ return builder.create <math::FmaOp>(a, b, c);
922
+ };
923
+
924
+ auto mul = [&](Value a, Value b) -> Value {
925
+ return builder.create <arith::MulFOp>(a, b);
926
+ };
927
+
928
+ Value negOperand = builder.create <arith::NegFOp>(operand);
929
+ Value zero = bcast (floatCst (builder, 0.0 , elementType));
930
+ Value half = bcast (floatCst (builder, 0.5 , elementType));
931
+ Value negOne = bcast (floatCst (builder, -1.0 , elementType));
932
+ Value selR =
933
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
934
+ Value r = builder.create <arith::SelectOp>(selR, negOperand, operand);
935
+ Value chkConst = bcast (floatCst (builder, -0.5625 , elementType));
936
+ Value firstPred =
937
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
938
+
939
+ Value trueVal =
940
+ fma (bcast (floatCst (builder, 9.3282184640716537e-1 , elementType)),
941
+ bcast (floatCst (builder, 1.6839188885261840e+0 , elementType)),
942
+ builder.create <math::AsinOp>(r));
943
+
944
+ Value falseVal = builder.create <math::SqrtOp>(fma (half, r, half));
945
+ falseVal = builder.create <math::AsinOp>(falseVal);
946
+ falseVal = mul (bcast (floatCst (builder, 2.0 , elementType)), falseVal);
947
+
948
+ r = builder.create <arith::SelectOp>(firstPred, trueVal, falseVal);
949
+
950
+ // Check whether the operand lies in between [-1.0, 0.0).
951
+ Value greaterThanNegOne =
952
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
953
+
954
+ Value lessThanZero =
955
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
956
+
957
+ Value betweenNegOneZero =
958
+ builder.create <arith::AndIOp>(greaterThanNegOne, lessThanZero);
959
+
960
+ trueVal = fma (bcast (floatCst (builder, 1.8656436928143307e+0 , elementType)),
961
+ bcast (floatCst (builder, 1.6839188885261840e+0 , elementType)),
962
+ builder.create <arith::NegFOp>(r));
963
+
964
+ Value finalVal =
965
+ builder.create <arith::SelectOp>(betweenNegOneZero, trueVal, r);
966
+
967
+ rewriter.replaceOp (op, finalVal);
968
+ return success ();
969
+ }
970
+
824
971
// ----------------------------------------------------------------------------//
825
972
// Erf approximation.
826
973
// ----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
1505
1652
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1506
1653
patterns.getContext ());
1507
1654
1508
- patterns.add <AtanApproximation, Atan2Approximation, TanhApproximation,
1509
- LogApproximation, Log2Approximation, Log1pApproximation,
1510
- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1511
- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1512
- SinAndCosApproximation<false , math::CosOp>>(
1513
- patterns.getContext ());
1655
+ patterns
1656
+ .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1657
+ LogApproximation, Log2Approximation, Log1pApproximation,
1658
+ ErfPolynomialApproximation, AsinPolynomialApproximation,
1659
+ AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1660
+ CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1661
+ SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
1514
1662
if (options.enableAvx2 ) {
1515
1663
patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1516
1664
patterns.getContext ());
0 commit comments