Skip to content

Commit 7e2d672

Browse files
committed
Add polynomial approximation for trigonometric sine and cosine functions
The approximation relies on a range-reduced value y \in [0, pi/2]. For an input x, sin(x) equals sin(y), -sin(y), cos(y), or -cos(y), depending on which quadrant x is in, where sin(y) and cos(y) are approximated with a 5th-degree polynomial (in y^2). As a result, a single pattern can be used to compute the approximation for both sine and cosine. Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D104582
1 parent 1244bca commit 7e2d672

File tree

2 files changed

+215
-1
lines changed

2 files changed

+215
-1
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,11 +630,146 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
630630
return success();
631631
}
632632

633+
//----------------------------------------------------------------------------//
634+
// Sin and Cos approximation.
635+
//----------------------------------------------------------------------------//
636+
637+
namespace {
638+
639+
// Rewrite pattern shared by math::SinOp and math::CosOp. `isSine` selects
// which of the two functions the pattern approximates; `OpTy` is the op type
// being matched.
template <bool isSine, typename OpTy>
640+
struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
641+
public:
642+
// Reuse the constructors of OpRewritePattern.
using OpRewritePattern<OpTy>::OpRewritePattern;
643+
644+
// Replaces `op` with its polynomial approximation; defined out of line.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
645+
};
646+
} // namespace
647+
648+
// 2/pi as a long double literal; multiplied with the input to find the
// quadrant index during range reduction in the sin/cos approximation.
#define TWO_OVER_PI \
649+
0.6366197723675813430755350534900574481378385829618257949906693762L
650+
// pi/2 as a long double literal; the width of one quadrant, used to compute
// the reduced argument in the sin/cos approximation.
#define PI_OVER_2 \
651+
1.5707963267948966192313216916397514420985846996875529104874722961L
652+
653+
// Approximates sin(x) or cos(x) with polynomials evaluated on a reduced
654+
// argument. The input is reduced to y = x - k * pi/2 with k = floor(x * 2/pi),
655+
// and the result is sin(y), -sin(y), cos(y) or -cos(y), selected by k mod 4.
656+
template <bool isSine, typename OpTy>
657+
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
658+
OpTy op, PatternRewriter &rewriter) const {
659+
static_assert(
660+
llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
661+
"SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
662+
// Only proceed when vectorWidth yields a value for the operand type (checked
// with the isF32 predicate); otherwise report a match failure.
auto width = vectorWidth(op.operand().getType(), isF32);
663+
if (!width.hasValue())
664+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
665+
666+
// All ops below are created at the location of the original op.
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
667+
// Small helpers wrapping op creation so the math below reads linearly.
auto bcast = [&](Value value) -> Value {
668+
return broadcast(builder, value, *width);
669+
};
670+
auto mul = [&](Value a, Value b) -> Value {
671+
return builder.create<MulFOp>(a, b);
672+
};
673+
auto sub = [&](Value a, Value b) -> Value {
674+
return builder.create<SubFOp>(a, b);
675+
};
676+
auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
677+
678+
// i32 type broadcast to the operand's width; result type of the
// float-to-signed-integer conversion.
auto i32Vec = broadcast(builder.getI32Type(), *width);
679+
auto fPToSingedInteger = [&](Value a) -> Value {
680+
return builder.create<FPToSIOp>(a, i32Vec);
681+
};
682+
683+
// a mod 4 computed as a bitwise AND with 3.
auto modulo4 = [&](Value a) -> Value {
684+
return builder.create<AndOp>(a, bcast(i32Cst(builder, 3)));
685+
};
686+
687+
auto isEqualTo = [&](Value a, Value b) -> Value {
688+
return builder.create<CmpIOp>(CmpIPredicate::eq, a, b);
689+
};
690+
691+
// Signed integer "greater than" comparison.
auto isGreaterThan = [&](Value a, Value b) -> Value {
692+
return builder.create<CmpIOp>(CmpIPredicate::sgt, a, b);
693+
};
694+
695+
auto select = [&](Value cond, Value t, Value f) -> Value {
696+
return builder.create<SelectOp>(cond, t, f);
697+
};
698+
699+
auto fmla = [&](Value a, Value b, Value c) {
700+
return builder.create<FmaFOp>(a, b, c);
701+
};
702+
703+
auto bitwiseOr = [&](Value a, Value b) { return builder.create<OrOp>(a, b); };
704+
705+
Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI));
706+
Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2));
707+
708+
Value x = op.operand();
709+
710+
// Range reduction: k counts whole multiples of pi/2 in x (rounded down).
Value k = floor(mul(x, twoOverPi));
711+
712+
// Reduced argument y = x - k * pi/2.
Value y = sub(x, mul(k, piOverTwo));
713+
714+
Value cstOne = bcast(f32Cst(builder, 1.0));
715+
Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
716+
717+
// Sine polynomial coefficients: sin(y) is evaluated as
// y * (1 + SC2*y^2 + SC4*y^4 + SC6*y^6 + SC8*y^8 + SC10*y^10).
Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
718+
Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
719+
Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
720+
Value cstSC8 =
721+
bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
722+
Value cstSC10 =
723+
bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
724+
725+
// Cosine polynomial coefficients: cos(y) is evaluated as
// 1 + CC2*y^2 + CC4*y^4 + CC6*y^6 + CC8*y^8 + CC10*y^10.
Value cstCC2 = bcast(f32Cst(builder, -0.5f));
726+
Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
727+
Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
728+
Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
729+
Value cstCC10 =
730+
bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
731+
732+
// Quadrant index: k converted to a signed integer, then taken mod 4.
Value kMod4 = modulo4(fPToSingedInteger(k));
733+
734+
// One predicate per possible quadrant index.
Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
735+
Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
736+
Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
737+
Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
738+
739+
// True when the cosine polynomial must be used: quadrants 1 and 3 for sine,
// quadrants 0 and 2 for cosine.
Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
740+
// True when the polynomial value must be negated: quadrants 2 and 3 for
// sine, quadrants 1 and 2 for cosine.
Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
741+
: bitwiseOr(kR1, kR2);
742+
743+
Value y2 = mul(y, y);
744+
745+
// Pick the leading factor (1 for the cosine form, y for the sine form) and
// the matching coefficient set, so one evaluation chain serves both.
Value base = select(sinuseCos, cstOne, y);
746+
Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
747+
Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
748+
Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
749+
Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
750+
Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
751+
752+
// Horner-style evaluation in y^2 using FmaFOp, from the highest coefficient
// down to the constant term 1, followed by the multiplication with `base`.
Value v1 = fmla(y2, cstC10, cstC8);
753+
Value v2 = fmla(y2, v1, cstC6);
754+
Value v3 = fmla(y2, v2, cstC4);
755+
Value v4 = fmla(y2, v3, cstC2);
756+
Value v5 = fmla(y2, v4, cstOne);
757+
Value v6 = mul(base, v5);
758+
759+
// Apply the sign required by the quadrant.
Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
760+
761+
rewriter.replaceOp(op, approximation);
762+
763+
return success();
764+
}
765+
633766
//----------------------------------------------------------------------------//
634767

635768
// Adds the polynomial approximation rewrite patterns to `patterns`, using the
// context obtained from `patterns.getContext()`. This diff hunk shows the
// previous pattern list (removed line) followed by the new list, which appends
// the sine and cosine instantiations of SinAndCosApproximation.
void mlir::populateMathPolynomialApproximationPatterns(
636769
RewritePatternSet &patterns) {
637770
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
638-
Log1pApproximation, ExpApproximation, ExpM1Approximation>(
771+
Log1pApproximation, ExpApproximation, ExpM1Approximation,
772+
SinAndCosApproximation<true, math::SinOp>,
773+
SinAndCosApproximation<false, math::CosOp>>(
639774
patterns.getContext());
640775
}

mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,83 @@ func @expm1() {
219219

220220
return
221221
}
222+
// -------------------------------------------------------------------------- //
223+
// Sin.
224+
// -------------------------------------------------------------------------- //
225+
// Applies math.sin to f32 scalars (0, pi/4, pi/2, pi, 3*pi/2) and to one
// vector<3xf32> (3*pi, 2*pi/3, -pi/2), printing each result with vector.print.
// The FileCheck directive above each case gives the expected printed value.
func @sin() {
226+
// CHECK: 0
227+
%0 = constant 0.0 : f32
228+
%sin_0 = math.sin %0 : f32
229+
vector.print %sin_0 : f32
230+
231+
// CHECK: 0.707107
232+
%pi_over_4 = constant 0.78539816339 : f32
233+
%sin_pi_over_4 = math.sin %pi_over_4 : f32
234+
vector.print %sin_pi_over_4 : f32
235+
236+
// CHECK: 1
237+
%pi_over_2 = constant 1.57079632679 : f32
238+
%sin_pi_over_2 = math.sin %pi_over_2 : f32
239+
vector.print %sin_pi_over_2 : f32
240+
241+
242+
// CHECK: 0
243+
%pi = constant 3.14159265359 : f32
244+
%sin_pi = math.sin %pi : f32
245+
vector.print %sin_pi : f32
246+
247+
// CHECK: -1
248+
%pi_3_over_2 = constant 4.71238898038 : f32
249+
%sin_pi_3_over_2 = math.sin %pi_3_over_2 : f32
250+
vector.print %sin_pi_3_over_2 : f32
251+
252+
// CHECK: 0, 0.866025, -1
253+
%vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
254+
%sin_vec_x = math.sin %vec_x : vector<3xf32>
255+
vector.print %sin_vec_x : vector<3xf32>
256+
257+
return
258+
}
259+
260+
// -------------------------------------------------------------------------- //
261+
// cos.
262+
// -------------------------------------------------------------------------- //
263+
264+
// Applies math.cos to f32 scalars (0, pi/4, pi/2, pi, 3*pi/2) and to one
// vector<3xf32> (3*pi, 2*pi/3, -pi/2), printing each result with vector.print.
// The FileCheck directive above each case gives the expected printed value.
// NOTE(review): the pi/2 and pi cases spell their directive with '////' and
// '///' instead of '//' -- confirm whether that is intentional.
func @cos() {
265+
// CHECK: 1
266+
%0 = constant 0.0 : f32
267+
%cos_0 = math.cos %0 : f32
268+
vector.print %cos_0 : f32
269+
270+
// CHECK: 0.707107
271+
%pi_over_4 = constant 0.78539816339 : f32
272+
%cos_pi_over_4 = math.cos %pi_over_4 : f32
273+
vector.print %cos_pi_over_4 : f32
274+
275+
//// CHECK: 0
276+
%pi_over_2 = constant 1.57079632679 : f32
277+
%cos_pi_over_2 = math.cos %pi_over_2 : f32
278+
vector.print %cos_pi_over_2 : f32
279+
280+
/// CHECK: -1
281+
%pi = constant 3.14159265359 : f32
282+
%cos_pi = math.cos %pi : f32
283+
vector.print %cos_pi : f32
284+
285+
// CHECK: 0
286+
%pi_3_over_2 = constant 4.71238898038 : f32
287+
%cos_pi_3_over_2 = math.cos %pi_3_over_2 : f32
288+
vector.print %cos_pi_3_over_2 : f32
289+
290+
// CHECK: -1, -0.5, 0
291+
%vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
292+
%cos_vec_x = math.cos %vec_x : vector<3xf32>
293+
vector.print %cos_vec_x : vector<3xf32>
294+
295+
296+
return
297+
}
298+
222299

223300
func @main() {
224301
call @tanh(): () -> ()
@@ -227,5 +304,7 @@ func @main() {
227304
call @log1p(): () -> ()
228305
call @exp(): () -> ()
229306
call @expm1(): () -> ()
307+
call @sin(): () -> ()
308+
call @cos(): () -> ()
230309
return
231310
}

0 commit comments

Comments
 (0)