Skip to content

Commit a3fb301

Browse files
authored
[mlir][math] Fix polynomial math.asin approximation (#101247)
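The polynomial approximation for asin is only accurate for inputs in [-9/16, 9/16]. Values outside that range must be remapped to achieve good numeric results. For x in [0, 1] this is done with the identity `arcsin(x) = PI/2 - arcsin(sqrt(1.0 - x*x))`; the computation is carried out on |x| and the sign of the input is restored afterwards with copysign.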
1 parent 28a0792 commit a3fb301

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,34 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
861861
return builder.create<arith::MulFOp>(a, b);
862862
};
863863

864-
Value s = mul(operand, operand);
864+
auto sub = [&](Value a, Value b) -> Value {
865+
return builder.create<arith::SubFOp>(a, b);
866+
};
867+
868+
auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); };
869+
870+
auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); };
871+
872+
auto scopy = [&](Value a, Value b) -> Value {
873+
return builder.create<math::CopySignOp>(a, b);
874+
};
875+
876+
auto sel = [&](Value a, Value b, Value c) -> Value {
877+
return builder.create<arith::SelectOp>(a, b, c);
878+
};
879+
880+
Value abso = abs(operand);
881+
Value aa = mul(operand, operand);
882+
Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));
883+
884+
Value gt =
885+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
886+
bcast(floatCst(builder, 0.5, elementType)));
887+
888+
Value x = sel(gt, opp, abso);
889+
890+
// Asin(x) approximation for x = [-9/16, 9/16]:
891+
Value s = mul(x, x);
865892
Value q = mul(s, s);
866893
Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
867894
Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
@@ -878,8 +905,12 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
878905
t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
879906
r = fma(r, s, t);
880907
r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
881-
t = mul(operand, s);
882-
r = fma(r, t, operand);
908+
t = mul(x, s);
909+
r = fma(r, t, x);
910+
911+
Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r);
912+
r = sel(gt, rsub, r);
913+
r = scopy(r, operand);
883914

884915
rewriter.replaceOp(op, r);
885916
return success();

mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,10 @@ func.func @asin() {
493493
%cst3 = arith.constant -0.25 : f32
494494
call @asin_f32(%cst3) : (f32) -> ()
495495

496+
// CHECK: -1.1197
497+
%cst4 = arith.constant -0.90 : f32
498+
call @asin_f32(%cst4) : (f32) -> ()
499+
496500
// CHECK: 0.25268, 0.384397, 0.597406
497501
%vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
498502
call @asin_3xf32(%vec_x) : (vector<3xf32>) -> ()

0 commit comments

Comments
 (0)