-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][complex] Add a numerically-stable lowering for complex.expm1. #115082
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Alexander Belyaev (pifon2a) Changes: The current conversion to Standard in the MLIR repo is not stable for small imag(arg). Full diff: https://github.com/llvm/llvm-project/pull/115082.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 6656be830989a4..9ebb18a6c4ba70 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
}
};
+/// Emits IR that evaluates a polynomial in `arg` with Horner's scheme.
+///
+/// `coefficients` are ordered from the highest-degree term down to the
+/// constant term: the result is (((c0 * arg + c1) * arg + c2) ...) with each
+/// step emitted as one `math.fma` carrying `fmf`. Every coefficient is
+/// materialized as an `arith.constant` of the type of `arg`, which must be a
+/// floating-point type. `coefficients` must contain at least one element.
+Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
+ ArrayRef<double> coefficients,
+ arith::FastMathFlagsAttr fmf) {
+ auto argType = mlir::cast<FloatType>(arg.getType());
+ // Seed the accumulator with the leading coefficient.
+ Value poly = b.create<arith::ConstantOp>(
+ b.getFloatAttr(argType, coefficients.front()));
+ // Iterate over the remaining coefficients directly; this avoids the
+ // signed/unsigned comparison of an `int` index against `size()`.
+ for (double coefficient : coefficients.drop_front()) {
+ poly = b.create<math::FmaOp>(
+ poly, arg,
+ b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficient)),
+ fmf);
+ }
+ return poly;
+}
+
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
+ // Lowers complex.expm1 to arith/math ops on the real part `a` and the
+ // imaginary part `b` of the operand:
+ //
+ // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
+ // [handle inaccuracies when a and/or b are small]
+ // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
+ // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
+ //
+ // Here e^a is formed as expm1(a) + 1 and cos(b) as cosm1(b) + 1. The
+ // fastmath flags of the original op are forwarded to the emitted
+ // floating-point ops (see the note in emitCosm1 for one exception).
LogicalResult
matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = cast<ComplexType>(adaptor.getComplex().getType());
- auto elementType = cast<FloatType>(type.getElementType());
+ auto type = op.getType();
+ auto elemType = mlir::cast<FloatType>(type.getElementType());
+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value real = b.create<complex::ReOp>(adaptor.getComplex());
+ Value imag = b.create<complex::ImOp>(adaptor.getComplex());
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
+ Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
+ Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
- Value real = b.create<complex::ReOp>(elementType, exp);
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
- Value imag = b.create<complex::ImOp>(elementType, exp);
+ // expm1(a) and e^a = expm1(a) + 1.
+ Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
+ Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+
+ // sin(b), cosm1(b) = cos(b) - 1 and cos(b) = cosm1(b) + 1.
+ Value sinImag = b.create<math::SinOp>(imag, fmf);
+ Value cosm1Imag = emitCosm1(imag, fmf, b);
+ Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
- rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
- imag);
+ // Real part: expm1(a) * cos(b) + cosm1(b).
+ Value realResult = b.create<arith::AddFOp>(
+ b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+
+ // Imaginary part: e^a * sin(b), except that it is forced to the constant
+ // zero when imag(arg) compares equal to zero.
+ // NOTE(review): presumably this avoids inf * 0 = NaN when e^a overflows -
+ // confirm.
+ Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
+ zero, fmf.getValue());
+ Value imagResult = b.create<arith::SelectOp>(
+ imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
+
+ rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
+ imagResult);
return success();
}
+
+private:
+ // Emits cos(arg) - 1 for a floating-point `arg`.
+ //
+ // Two candidates are emitted and one is picked with arith.select:
+ // * arg^2 >= 0.61685 (approximately (pi/4)^2): math.cos(arg) + (-1);
+ // * otherwise: arg^4 * P(arg^2) + (-0.5) * arg^2, where P is the
+ // polynomial given by kCoeffs.
+ Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
+ ImplicitLocOpBuilder &b) const {
+ auto argType = mlir::cast<FloatType>(arg.getType());
+ auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
+ auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+
+ // Algorithm copied from cephes cosm1. Coefficients are listed from the
+ // highest-degree term down to the constant term.
+ static constexpr double kCoeffs[] = {
+ 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
+ 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
+ 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
+ 4.1666666666666666609054E-2,
+ };
+ Value cos = b.create<math::CosOp>(arg, fmf);
+ Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
+
+ Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
+ Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+ Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
+
+ // Create the two products as separate statements: the evaluation order of
+ // function arguments is unspecified in C++, so nesting both creates inside
+ // one call would make the order of the emitted ops compiler-dependent.
+ Value polyTerm = b.create<arith::MulFOp>(argPow4, poly, fmf);
+ Value quadraticTerm = b.create<arith::MulFOp>(negHalf, argPow2, fmf);
+ // NOTE(review): unlike the neighbouring ops, this add is created without
+ // `fmf` - confirm whether that is intentional.
+ auto forSmallArg = b.create<arith::AddFOp>(polyTerm, quadraticTerm);
+
+ // (pi/4)^2 is approximately 0.61685
+ Value piOver4Pow2 =
+ b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
+ Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
+ piOver4Pow2, fmf.getValue());
+ return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+ }
};
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index d7767bda08435f..1e2724e17d765e 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -221,26 +221,52 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// -----
-// CHECK-LABEL: func.func @complex_expm1(
-// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+// CHECK-LABEL: func.func @complex_expm1(
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
- %expm1 = complex.expm1 %arg: complex<f32>
+ %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
return %expm1 : complex<f32>
}
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C1_F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[EXPM1:.*]] = math.expm1 %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[EXPM1]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_7:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_8:.*]] = arith.constant -5.000000e-01 : f32
+// CHECK: %[[VAL_9:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[VAL_10:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_9]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_12:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_13:.*]] = arith.mulf %[[VAL_12]], %[[VAL_12]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF0:.*]] = arith.constant 4.73775072E-14 : f32
+// CHECK: %[[COEF1:.*]] = arith.constant -1.14702848E-11 : f32
+// CHECK: %[[FMA0:.*]] = math.fma %[[COEF0]], %[[VAL_12]], %[[COEF1]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF2:.*]] = arith.constant 2.08767537E-9 : f32
+// CHECK: %[[FMA1:.*]] = math.fma %[[FMA0]], %[[VAL_12]], %[[COEF2]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF3:.*]] = arith.constant -2.755732E-7 : f32
+// CHECK: %[[FMA2:.*]] = math.fma %[[FMA1]], %[[VAL_12]], %[[COEF3]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF4:.*]] = arith.constant 2.48015876E-5 : f32
+// CHECK: %[[FMA3:.*]] = math.fma %[[FMA2]], %[[VAL_12]], %[[COEF4]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF5:.*]] = arith.constant -0.00138888892 : f32
+// CHECK: %[[FMA4:.*]] = math.fma %[[FMA3]], %[[VAL_12]], %[[COEF5]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF6:.*]] = arith.constant 0.0416666679 : f32
+// CHECK: %[[FMA5:.*]] = math.fma %[[FMA4]], %[[VAL_12]], %[[COEF6]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_13]], %[[FMA5]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_8]], %[[VAL_12]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32
+// CHECK: %[[VAL_30:.*]] = arith.constant 6.168500e-01 : f32
+// CHECK: %[[VAL_31:.*]] = arith.cmpf oge, %[[VAL_12]], %[[VAL_30]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_11]], %[[VAL_29]] : f32
+// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_32]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_34:.*]] = arith.mulf %[[EXPM1]], %[[VAL_33]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_36:.*]] = arith.cmpf oeq, %[[IMAG]], %[[C0_F32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_37:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[C0_F32]], %[[VAL_37]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[VAL_35]], %[[VAL_38]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
// -----
@@ -882,29 +908,6 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
// -----
-// CHECK-LABEL: func.func @complex_expm1_with_fmf(
-// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
-func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
- %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
- return %expm1 : complex<f32>
-}
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
-
-// -----
-
// CHECK-LABEL: func @complex_log_with_fmf
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
|
Value realResult = b.create<arith::AddFOp>( | ||
b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf); | ||
|
||
Value imageIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: imageIsZero -> imagIsZero?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
The current conversion to Standard in the MLIR repo is not stable for small imag(arg).
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts d2e313c PiperOrigin-RevId: 698191660
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts e0ccb4b PiperOrigin-RevId: 698191660
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts d2e313c PiperOrigin-RevId: 698191660
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts e0ccb4b PiperOrigin-RevId: 698191660
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts d2e313c PiperOrigin-RevId: 702291891
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts e0ccb4b PiperOrigin-RevId: 702291891
The current conversion to Standard in the MLIR repo is not stable for small imag(arg).