Skip to content

Commit 18ee003

Browse files
authored
[mlir][complex] Add a numerically-stable lowering for complex.expm1. (#115082)
The current conversion to Standard in the MLIR repo is not stable for small imag(arg).
1 parent 90e9223 commit 18ee003

File tree

2 files changed

+120
-52
lines changed

2 files changed

+120
-52
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
520520
}
521521
};
522522

523+
Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
524+
ArrayRef<double> coefficients,
525+
arith::FastMathFlagsAttr fmf) {
526+
auto argType = mlir::cast<FloatType>(arg.getType());
527+
Value poly =
528+
b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
529+
for (int i = 1; i < coefficients.size(); ++i) {
530+
poly = b.create<math::FmaOp>(
531+
poly, arg,
532+
b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
533+
fmf);
534+
}
535+
return poly;
536+
}
537+
523538
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
524539
using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
525540

541+
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
542+
// [handle inaccuracies when a and/or b are small]
543+
// = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
544+
// = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
526545
LogicalResult
527546
matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
528547
ConversionPatternRewriter &rewriter) const override {
529-
auto type = cast<ComplexType>(adaptor.getComplex().getType());
530-
auto elementType = cast<FloatType>(type.getElementType());
548+
auto type = op.getType();
549+
auto elemType = mlir::cast<FloatType>(type.getElementType());
550+
531551
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
552+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
553+
Value real = b.create<complex::ReOp>(adaptor.getComplex());
554+
Value imag = b.create<complex::ImOp>(adaptor.getComplex());
532555

533-
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
534-
Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
556+
Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
557+
Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
535558

536-
Value real = b.create<complex::ReOp>(elementType, exp);
537-
Value one = b.create<arith::ConstantOp>(elementType,
538-
b.getFloatAttr(elementType, 1));
539-
Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
540-
Value imag = b.create<complex::ImOp>(elementType, exp);
559+
Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
560+
Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
561+
562+
Value sinImag = b.create<math::SinOp>(imag, fmf);
563+
Value cosm1Imag = emitCosm1(imag, fmf, b);
564+
Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
541565

542-
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
543-
imag);
566+
Value realResult = b.create<arith::AddFOp>(
567+
b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
568+
569+
Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
570+
zero, fmf.getValue());
571+
Value imagResult = b.create<arith::SelectOp>(
572+
imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
573+
574+
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
575+
imagResult);
544576
return success();
545577
}
578+
579+
private:
580+
Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
581+
ImplicitLocOpBuilder &b) const {
582+
auto argType = mlir::cast<FloatType>(arg.getType());
583+
auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
584+
auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
585+
586+
// Algorithm copied from cephes cosm1.
587+
SmallVector<double, 7> kCoeffs{
588+
4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589+
2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
590+
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
591+
4.1666666666666666609054E-2,
592+
};
593+
Value cos = b.create<math::CosOp>(arg, fmf);
594+
Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
595+
596+
Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
597+
Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
598+
Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
599+
600+
auto forSmallArg =
601+
b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
602+
b.create<arith::MulFOp>(negHalf, argPow2, fmf));
603+
604+
// (pi/4)^2 is approximately 0.61685
605+
Value piOver4Pow2 =
606+
b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
607+
Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
608+
piOver4Pow2, fmf.getValue());
609+
return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
610+
}
546611
};
547612

548613
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -221,26 +221,52 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
221221

222222
// -----
223223

224-
// CHECK-LABEL: func.func @complex_expm1(
225-
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
224+
// CHECK-LABEL: func.func @complex_expm1(
225+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
226226
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
227-
%expm1 = complex.expm1 %arg: complex<f32>
227+
%expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
228228
return %expm1 : complex<f32>
229229
}
230-
// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
231-
// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
232-
// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32
233-
// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32
234-
// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32
235-
// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32
236-
// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32
237-
// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
238-
// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
239-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
240-
// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32
241-
// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
242-
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
243-
// CHECK: return %[[RES]] : complex<f32>
230+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
231+
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
232+
// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
233+
// CHECK-DAG: %[[C1_F32:.*]] = arith.constant 1.000000e+00 : f32
234+
// CHECK: %[[EXPM1:.*]] = math.expm1 %[[REAL]] fastmath<nnan,contract> : f32
235+
// CHECK: %[[VAL_6:.*]] = arith.addf %[[EXPM1]], %[[C1_F32]] fastmath<nnan,contract> : f32
236+
// CHECK: %[[VAL_7:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
237+
// CHECK: %[[VAL_8:.*]] = arith.constant -5.000000e-01 : f32
238+
// CHECK: %[[VAL_9:.*]] = arith.constant -1.000000e+00 : f32
239+
// CHECK: %[[VAL_10:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
240+
// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_9]] fastmath<nnan,contract> : f32
241+
// CHECK: %[[VAL_12:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
242+
// CHECK: %[[VAL_13:.*]] = arith.mulf %[[VAL_12]], %[[VAL_12]] fastmath<nnan,contract> : f32
243+
// CHECK-DAG: %[[COEF0:.*]] = arith.constant 4.73775072E-14 : f32
244+
// CHECK-DAG: %[[COEF1:.*]] = arith.constant -1.14702848E-11 : f32
245+
// CHECK: %[[FMA0:.*]] = math.fma %[[COEF0]], %[[VAL_12]], %[[COEF1]] fastmath<nnan,contract> : f32
246+
// CHECK: %[[COEF2:.*]] = arith.constant 2.08767537E-9 : f32
247+
// CHECK: %[[FMA1:.*]] = math.fma %[[FMA0]], %[[VAL_12]], %[[COEF2]] fastmath<nnan,contract> : f32
248+
// CHECK: %[[COEF3:.*]] = arith.constant -2.755732E-7 : f32
249+
// CHECK: %[[FMA2:.*]] = math.fma %[[FMA1]], %[[VAL_12]], %[[COEF3]] fastmath<nnan,contract> : f32
250+
// CHECK: %[[COEF4:.*]] = arith.constant 2.48015876E-5 : f32
251+
// CHECK: %[[FMA3:.*]] = math.fma %[[FMA2]], %[[VAL_12]], %[[COEF4]] fastmath<nnan,contract> : f32
252+
// CHECK: %[[COEF5:.*]] = arith.constant -0.00138888892 : f32
253+
// CHECK: %[[FMA4:.*]] = math.fma %[[FMA3]], %[[VAL_12]], %[[COEF5]] fastmath<nnan,contract> : f32
254+
// CHECK: %[[COEF6:.*]] = arith.constant 0.0416666679 : f32
255+
// CHECK: %[[FMA5:.*]] = math.fma %[[FMA4]], %[[VAL_12]], %[[COEF6]] fastmath<nnan,contract> : f32
256+
// CHECK-DAG: %[[VAL_27:.*]] = arith.mulf %[[VAL_13]], %[[FMA5]] fastmath<nnan,contract> : f32
257+
// CHECK-DAG: %[[VAL_28:.*]] = arith.mulf %[[VAL_8]], %[[VAL_12]] fastmath<nnan,contract> : f32
258+
// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32
259+
// CHECK: %[[VAL_30:.*]] = arith.constant 6.168500e-01 : f32
260+
// CHECK: %[[VAL_31:.*]] = arith.cmpf oge, %[[VAL_12]], %[[VAL_30]] fastmath<nnan,contract> : f32
261+
// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_11]], %[[VAL_29]] : f32
262+
// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_32]], %[[C1_F32]] fastmath<nnan,contract> : f32
263+
// CHECK: %[[VAL_34:.*]] = arith.mulf %[[EXPM1]], %[[VAL_33]] fastmath<nnan,contract> : f32
264+
// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_32]] fastmath<nnan,contract> : f32
265+
// CHECK: %[[VAL_36:.*]] = arith.cmpf oeq, %[[IMAG]], %[[C0_F32]] fastmath<nnan,contract> : f32
266+
// CHECK: %[[VAL_37:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<nnan,contract> : f32
267+
// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[C0_F32]], %[[VAL_37]] : f32
268+
// CHECK: %[[RESULT:.*]] = complex.create %[[VAL_35]], %[[VAL_38]] : complex<f32>
269+
// CHECK: return %[[RESULT]] : complex<f32>
244270

245271
// -----
246272

@@ -882,29 +908,6 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
882908

883909
// -----
884910

885-
// CHECK-LABEL: func.func @complex_expm1_with_fmf(
886-
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
887-
func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
888-
%expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
889-
return %expm1 : complex<f32>
890-
}
891-
// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
892-
// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
893-
// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
894-
// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
895-
// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
896-
// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
897-
// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
898-
// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
899-
// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
900-
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
901-
// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
902-
// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
903-
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
904-
// CHECK: return %[[RES]] : complex<f32>
905-
906-
// -----
907-
908911
// CHECK-LABEL: func @complex_log_with_fmf
909912
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
910913
func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
@@ -2020,4 +2023,4 @@ func.func @complex_angle_with_fmf(%arg: complex<f32>) -> f32 {
20202023
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
20212024
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
20222025
// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
2023-
// CHECK: return %[[RESULT]] : f32
2026+
// CHECK: return %[[RESULT]] : f32

0 commit comments

Comments
 (0)