@@ -630,11 +630,146 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
630
630
return success ();
631
631
}
632
632
633
+ // ----------------------------------------------------------------------------//
634
+ // Sin and Cos approximation.
635
+ // ----------------------------------------------------------------------------//
636
+
637
+ namespace {
638
+
639
+ template <bool isSine, typename OpTy>
640
+ struct SinAndCosApproximation : public OpRewritePattern <OpTy> {
641
+ public:
642
+ using OpRewritePattern<OpTy>::OpRewritePattern;
643
+
644
+ LogicalResult matchAndRewrite (OpTy op, PatternRewriter &rewriter) const final ;
645
+ };
646
+ } // namespace
647
+
648
+ #define TWO_OVER_PI \
649
+ 0 .6366197723675813430755350534900574481378385829618257949906693762L
650
+ #define PI_OVER_2 \
651
+ 1 .5707963267948966192313216916397514420985846996875529104874722961L
652
+
653
+ // Approximates sin(x) or cos(x) by finding the best approximation polynomial in
654
+ // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
655
+ // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
656
+ template <bool isSine, typename OpTy>
657
+ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
658
+ OpTy op, PatternRewriter &rewriter) const {
659
+ static_assert (
660
+ llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
661
+ " SinAndCosApproximation pattern expects math::SinOp or math::CosOp" );
662
+ auto width = vectorWidth (op.operand ().getType (), isF32);
663
+ if (!width.hasValue ())
664
+ return rewriter.notifyMatchFailure (op, " unsupported operand type" );
665
+
666
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
667
+ auto bcast = [&](Value value) -> Value {
668
+ return broadcast (builder, value, *width);
669
+ };
670
+ auto mul = [&](Value a, Value b) -> Value {
671
+ return builder.create <MulFOp>(a, b);
672
+ };
673
+ auto sub = [&](Value a, Value b) -> Value {
674
+ return builder.create <SubFOp>(a, b);
675
+ };
676
+ auto floor = [&](Value a) { return builder.create <FloorFOp>(a); };
677
+
678
+ auto i32Vec = broadcast (builder.getI32Type (), *width);
679
+ auto fPToSingedInteger = [&](Value a) -> Value {
680
+ return builder.create <FPToSIOp>(a, i32Vec);
681
+ };
682
+
683
+ auto modulo4 = [&](Value a) -> Value {
684
+ return builder.create <AndOp>(a, bcast (i32Cst (builder, 3 )));
685
+ };
686
+
687
+ auto isEqualTo = [&](Value a, Value b) -> Value {
688
+ return builder.create <CmpIOp>(CmpIPredicate::eq, a, b);
689
+ };
690
+
691
+ auto isGreaterThan = [&](Value a, Value b) -> Value {
692
+ return builder.create <CmpIOp>(CmpIPredicate::sgt, a, b);
693
+ };
694
+
695
+ auto select = [&](Value cond, Value t, Value f) -> Value {
696
+ return builder.create <SelectOp>(cond, t, f);
697
+ };
698
+
699
+ auto fmla = [&](Value a, Value b, Value c) {
700
+ return builder.create <FmaFOp>(a, b, c);
701
+ };
702
+
703
+ auto bitwiseOr = [&](Value a, Value b) { return builder.create <OrOp>(a, b); };
704
+
705
+ Value twoOverPi = bcast (f32Cst (builder, TWO_OVER_PI));
706
+ Value piOverTwo = bcast (f32Cst (builder, PI_OVER_2));
707
+
708
+ Value x = op.operand ();
709
+
710
+ Value k = floor (mul (x, twoOverPi));
711
+
712
+ Value y = sub (x, mul (k, piOverTwo));
713
+
714
+ Value cstOne = bcast (f32Cst (builder, 1.0 ));
715
+ Value cstNegativeOne = bcast (f32Cst (builder, -1.0 ));
716
+
717
+ Value cstSC2 = bcast (f32Cst (builder, -0 .16666667163372039794921875f ));
718
+ Value cstSC4 = bcast (f32Cst (builder, 8 .333347737789154052734375e-3f ));
719
+ Value cstSC6 = bcast (f32Cst (builder, -1 .9842604524455964565277099609375e-4f ));
720
+ Value cstSC8 =
721
+ bcast (f32Cst (builder, 2 .760012648650445044040679931640625e-6f ));
722
+ Value cstSC10 =
723
+ bcast (f32Cst (builder, -2 .50293279435709337121807038784027099609375e-8f ));
724
+
725
+ Value cstCC2 = bcast (f32Cst (builder, -0 .5f ));
726
+ Value cstCC4 = bcast (f32Cst (builder, 4 .166664183139801025390625e-2f ));
727
+ Value cstCC6 = bcast (f32Cst (builder, -1 .388833043165504932403564453125e-3f ));
728
+ Value cstCC8 = bcast (f32Cst (builder, 2 .47562347794882953166961669921875e-5f ));
729
+ Value cstCC10 =
730
+ bcast (f32Cst (builder, -2 .59630184018533327616751194000244140625e-7f ));
731
+
732
+ Value kMod4 = modulo4 (fPToSingedInteger (k));
733
+
734
+ Value kR0 = isEqualTo (kMod4 , bcast (i32Cst (builder, 0 )));
735
+ Value kR1 = isEqualTo (kMod4 , bcast (i32Cst (builder, 1 )));
736
+ Value kR2 = isEqualTo (kMod4 , bcast (i32Cst (builder, 2 )));
737
+ Value kR3 = isEqualTo (kMod4 , bcast (i32Cst (builder, 3 )));
738
+
739
+ Value sinuseCos = isSine ? bitwiseOr (kR1 , kR3 ) : bitwiseOr (kR0 , kR2 );
740
+ Value negativeRange = isSine ? isGreaterThan (kMod4 , bcast (i32Cst (builder, 1 )))
741
+ : bitwiseOr (kR1 , kR2 );
742
+
743
+ Value y2 = mul (y, y);
744
+
745
+ Value base = select (sinuseCos, cstOne, y);
746
+ Value cstC2 = select (sinuseCos, cstCC2, cstSC2);
747
+ Value cstC4 = select (sinuseCos, cstCC4, cstSC4);
748
+ Value cstC6 = select (sinuseCos, cstCC6, cstSC6);
749
+ Value cstC8 = select (sinuseCos, cstCC8, cstSC8);
750
+ Value cstC10 = select (sinuseCos, cstCC10, cstSC10);
751
+
752
+ Value v1 = fmla (y2, cstC10, cstC8);
753
+ Value v2 = fmla (y2, v1, cstC6);
754
+ Value v3 = fmla (y2, v2, cstC4);
755
+ Value v4 = fmla (y2, v3, cstC2);
756
+ Value v5 = fmla (y2, v4, cstOne);
757
+ Value v6 = mul (base, v5);
758
+
759
+ Value approximation = select (negativeRange, mul (cstNegativeOne, v6), v6);
760
+
761
+ rewriter.replaceOp (op, approximation);
762
+
763
+ return success ();
764
+ }
765
+
633
766
// ----------------------------------------------------------------------------//
634
767
635
768
void mlir::populateMathPolynomialApproximationPatterns (
636
769
RewritePatternSet &patterns) {
637
770
patterns.add <TanhApproximation, LogApproximation, Log2Approximation,
638
- Log1pApproximation, ExpApproximation, ExpM1Approximation>(
771
+ Log1pApproximation, ExpApproximation, ExpM1Approximation,
772
+ SinAndCosApproximation<true , math::SinOp>,
773
+ SinAndCosApproximation<false , math::CosOp>>(
639
774
patterns.getContext ());
640
775
}
0 commit comments