Skip to content

Commit 33e60f3

Browse files
authored
Fix rsqrt inaccuracies. (#88707)
The current lowering has issues with large/subnormal values. This ports XLA's lowering and was verified using XLA's test suite and the MLIR-based emitters. This updates #88691 to also update the correctness test for rsqrt(0). I checked C++ and Python; they both agree the result should be (inf, nan). Updated the correctness test to match this.
1 parent 0822780 commit 33e60f3

File tree

3 files changed

+93
-14
lines changed

3 files changed

+93
-14
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ using namespace mlir;
2727

2828
namespace {
2929

30-
// Returns the absolute value or its square root.
30+
enum class AbsFn { abs, sqrt, rsqrt };
31+
32+
// Returns the absolute value, its square root or its reciprocal square root.
3133
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
32-
ImplicitLocOpBuilder &b, bool returnSqrt = false) {
34+
ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
3335
Value one = b.create<arith::ConstantOp>(real.getType(),
3436
b.getFloatAttr(real.getType(), 1.0));
3537

@@ -43,7 +45,13 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
4345
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
4446
Value result;
4547

46-
if (returnSqrt) {
48+
if (fn == AbsFn::rsqrt) {
49+
ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
50+
min = b.create<math::RsqrtOp>(min, fmf);
51+
max = b.create<math::RsqrtOp>(max, fmf);
52+
}
53+
54+
if (fn == AbsFn::sqrt) {
4755
Value quarter = b.create<arith::ConstantOp>(
4856
real.getType(), b.getFloatAttr(real.getType(), 0.25));
4957
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -863,7 +871,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
863871

864872
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
865873
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
866-
Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
874+
Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
867875
Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
868876
Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
869877
Value cos = b.create<math::CosOp>(sqrtArg, fmf);
@@ -1147,18 +1155,74 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
11471155
LogicalResult
11481156
matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
11491157
ConversionPatternRewriter &rewriter) const override {
1150-
mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1158+
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
11511159
auto type = cast<ComplexType>(adaptor.getComplex().getType());
11521160
auto elementType = cast<FloatType>(type.getElementType());
11531161

1154-
Value c = builder.create<arith::ConstantOp>(
1155-
elementType, builder.getFloatAttr(elementType, -0.5));
1156-
Value d = builder.create<arith::ConstantOp>(
1157-
elementType, builder.getFloatAttr(elementType, 0));
1162+
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1163+
1164+
auto cst = [&](APFloat v) {
1165+
return b.create<arith::ConstantOp>(elementType,
1166+
b.getFloatAttr(elementType, v));
1167+
};
1168+
const auto &floatSemantics = elementType.getFloatSemantics();
1169+
Value zero = cst(APFloat::getZero(floatSemantics));
1170+
Value inf = cst(APFloat::getInf(floatSemantics));
1171+
Value negHalf = b.create<arith::ConstantOp>(
1172+
elementType, b.getFloatAttr(elementType, -0.5));
1173+
Value nan = cst(APFloat::getNaN(floatSemantics));
1174+
1175+
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1176+
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1177+
Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1178+
Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1179+
Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1180+
Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1181+
Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1182+
1183+
Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1184+
Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1185+
1186+
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1187+
arith::FastMathFlags::ninf)) {
1188+
Value negOne = b.create<arith::ConstantOp>(
1189+
elementType, b.getFloatAttr(elementType, -1));
1190+
1191+
Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1192+
Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1193+
Value negImagSignedZero =
1194+
b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
11581195

1159-
rewriter.replaceOp(op,
1160-
{powOpConversionImpl(builder, type, adaptor.getComplex(),
1161-
c, d, op.getFastmath())});
1196+
Value absReal = b.create<math::AbsFOp>(real, fmf);
1197+
Value absImag = b.create<math::AbsFOp>(imag, fmf);
1198+
1199+
Value absImagIsInf =
1200+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1201+
Value realIsNan =
1202+
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1203+
Value realIsInf =
1204+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1205+
Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1206+
1207+
Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1208+
1209+
resultReal =
1210+
b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1211+
resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1212+
resultImag);
1213+
}
1214+
1215+
Value isRealZero =
1216+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1217+
Value isImagZero =
1218+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1219+
Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1220+
1221+
resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1222+
resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1223+
1224+
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1225+
resultImag);
11621226
return success();
11631227
}
11641228
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,21 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
837837
return %rsqrt : complex<f32>
838838
}
839839

840+
// CHECK-COUNT-5: arith.select
841+
// CHECK-NOT: arith.select
842+
843+
// -----
844+
845+
// CHECK-LABEL: func @complex_rsqrt_nnan_ninf
846+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
847+
func.func @complex_rsqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
848+
%sqrt = complex.rsqrt %arg fastmath<nnan,ninf> : complex<f32>
849+
return %sqrt : complex<f32>
850+
}
851+
852+
// CHECK-COUNT-3: arith.select
853+
// CHECK-NOT: arith.select
854+
840855
// -----
841856

842857
// CHECK-LABEL: func.func @complex_angle
@@ -2103,4 +2118,4 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
21032118
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
21042119
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
21052120
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
2106-
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
2121+
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>

mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ func.func @entry() {
242242
// CHECK-NEXT: 0.321
243243
// CHECK-NEXT: -0.776
244244
(0.0, 0.0),
245-
// CHECK-NEXT: nan
245+
// CHECK-NEXT: inf
246246
// CHECK-NEXT: nan
247247
(0.0, 1.0),
248248
// CHECK-NEXT: 0.707

0 commit comments

Comments
 (0)