Skip to content

[mlir][complex] Add a numerically-stable lowering for complex.expm1. #115082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 18, 2024

Conversation

pifon2a
Copy link
Contributor

@pifon2a pifon2a commented Nov 5, 2024

The current conversion to Standard in the MLIR repo is not stable for small imag(arg).

@pifon2a pifon2a requested review from Lewuathe and akuegel November 5, 2024 22:48
@llvmbot llvmbot added the mlir label Nov 5, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir

Author: Alexander Belyaev (pifon2a)

Changes

The current conversion to Standard in the MLIR repo is not stable for small imag(arg).


Full diff: https://github.com/llvm/llvm-project/pull/115082.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+76-11)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+43-40)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 6656be830989a4..9ebb18a6c4ba70 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   }
 };
 
+Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
+                         ArrayRef<double> coefficients,
+                         arith::FastMathFlagsAttr fmf) {
+  auto argType = mlir::cast<FloatType>(arg.getType());
+  Value poly =
+      b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
+  for (int i = 1; i < coefficients.size(); ++i) {
+    poly = b.create<math::FmaOp>(
+        poly, arg,
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
+        fmf);
+  }
+  return poly;
+}
+
 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
 
+  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
+  //            [handle inaccuracies when a and/or b are small]
+  //            = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
+  //            = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
   LogicalResult
   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = cast<ComplexType>(adaptor.getComplex().getType());
-    auto elementType = cast<FloatType>(type.getElementType());
+    auto type = op.getType();
+    auto elemType = mlir::cast<FloatType>(type.getElementType());
+
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
 
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
+    Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
+    Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
 
-    Value real = b.create<complex::ReOp>(elementType, exp);
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1));
-    Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
-    Value imag = b.create<complex::ImOp>(elementType, exp);
+    Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
+    Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+
+    Value sinImag = b.create<math::SinOp>(imag, fmf);
+    Value cosm1Imag = emitCosm1(imag, fmf, b);
+    Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
 
-    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
-                                                   imag);
+    Value realResult = b.create<arith::AddFOp>(
+        b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+
+    Value imageIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
+                                                zero, fmf.getValue());
+    Value imagResult = b.create<arith::SelectOp>(
+        imageIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
+                                                   imagResult);
     return success();
   }
+
+private:
+  Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
+                  ImplicitLocOpBuilder &b) const {
+    auto argType = mlir::cast<FloatType>(arg.getType());
+    auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
+    auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+
+    // Algorithm copied from cephes cosm1.
+    SmallVector<double, 7> kCoeffs{
+        4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
+        2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
+        2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
+        4.1666666666666666609054E-2,
+    };
+    Value cos = b.create<math::CosOp>(arg, fmf);
+    Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
+
+    Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
+    Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+    Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
+
+    auto forSmallArg =
+        b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
+                                b.create<arith::MulFOp>(negHalf, argPow2, fmf));
+
+    // (pi/4)^2 is approximately 0.61685
+    Value piOver4Pow2 =
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
+    Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
+                                         piOver4Pow2, fmf.getValue());
+    return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+  }
 };
 
 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index d7767bda08435f..1e2724e17d765e 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -221,26 +221,52 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @complex_expm1(
-// CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+// CHECK-LABEL: func.func @complex_expm1(
+// CHECK-SAME:    %[[ARG:.*]]: complex<f32>) -> complex<f32> {
 func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
-  %expm1 = complex.expm1 %arg: complex<f32>
+  %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
   return %expm1 : complex<f32>
 }
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
+// CHECK:  %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK:  %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK:  %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:  %[[C1_F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK:  %[[EXPM1:.*]] = math.expm1 %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_6:.*]] = arith.addf %[[EXPM1]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_7:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_8:.*]] = arith.constant -5.000000e-01 : f32
+// CHECK:  %[[VAL_9:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK:  %[[VAL_10:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_9]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_12:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_13:.*]] = arith.mulf %[[VAL_12]], %[[VAL_12]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF0:.*]] = arith.constant 4.73775072E-14 : f32
+// CHECK:  %[[COEF1:.*]] = arith.constant -1.14702848E-11 : f32
+// CHECK:  %[[FMA0:.*]] = math.fma %[[COEF0]], %[[VAL_12]], %[[COEF1]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF2:.*]] = arith.constant 2.08767537E-9 : f32
+// CHECK:  %[[FMA1:.*]] = math.fma %[[FMA0]], %[[VAL_12]], %[[COEF2]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF3:.*]] = arith.constant -2.755732E-7 : f32
+// CHECK:  %[[FMA2:.*]] = math.fma %[[FMA1]], %[[VAL_12]], %[[COEF3]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF4:.*]] = arith.constant 2.48015876E-5 : f32
+// CHECK:  %[[FMA3:.*]] = math.fma %[[FMA2]], %[[VAL_12]], %[[COEF4]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF5:.*]] = arith.constant -0.00138888892 : f32
+// CHECK:  %[[FMA4:.*]] = math.fma %[[FMA3]], %[[VAL_12]], %[[COEF5]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF6:.*]] = arith.constant 0.0416666679 : f32
+// CHECK:  %[[FMA5:.*]] = math.fma %[[FMA4]], %[[VAL_12]], %[[COEF6]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_27:.*]] = arith.mulf %[[VAL_13]], %[[FMA5]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_28:.*]] = arith.mulf %[[VAL_8]], %[[VAL_12]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32
+// CHECK:  %[[VAL_30:.*]] = arith.constant 6.168500e-01 : f32
+// CHECK:  %[[VAL_31:.*]] = arith.cmpf oge, %[[VAL_12]], %[[VAL_30]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_11]], %[[VAL_29]] : f32
+// CHECK:  %[[VAL_33:.*]] = arith.addf %[[VAL_32]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_34:.*]] = arith.mulf %[[EXPM1]], %[[VAL_33]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_36:.*]] = arith.cmpf oeq, %[[IMAG]], %[[C0_F32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_37:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<nnan,contract> : f32
+// CHECK:  %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[C0_F32]], %[[VAL_37]] : f32
+// CHECK:  %[[RESULT:.*]] = complex.create %[[VAL_35]], %[[VAL_38]] : complex<f32>
+// CHECK:  return %[[RESULT]] : complex<f32>
 
 // -----
 
@@ -882,29 +908,6 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @complex_expm1_with_fmf(
-// CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
-func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
-  %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
-  return %expm1 : complex<f32>
-}
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
-
-// -----
-
 // CHECK-LABEL: func @complex_log_with_fmf
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {

Value realResult = b.create<arith::AddFOp>(
b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);

Value imageIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: imageIsZero -> imagIsZero?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

The current conversion to Standard in the MLIR repo is not stable for small
imag(arg).
@pifon2a pifon2a merged commit 18ee003 into llvm:main Nov 18, 2024
8 checks passed
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Nov 20, 2024
It was upstreamed in llvm/llvm-project#115082 (review)
Now we can use complex-to-standard pass.

Reverts d2e313c

PiperOrigin-RevId: 698191660
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 20, 2024
It was upstreamed in llvm/llvm-project#115082 (review)
Now we can use complex-to-standard pass.

Reverts e0ccb4b

PiperOrigin-RevId: 698191660
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Dec 3, 2024
It was upstreamed in llvm/llvm-project#115082 (review)
Now we can use complex-to-standard pass.

Reverts d2e313c

PiperOrigin-RevId: 698191660
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Dec 3, 2024
It was upstreamed in llvm/llvm-project#115082 (review)
Now we can use complex-to-standard pass.

Reverts e0ccb4b

PiperOrigin-RevId: 698191660
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Dec 3, 2024
It was upstreamed in llvm/llvm-project#115082 (review)
Now we can use complex-to-standard pass.

Reverts d2e313c

PiperOrigin-RevId: 702291891
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Dec 3, 2024
It was upstreamed in llvm/llvm-project#115082 (review)
Now we can use complex-to-standard pass.

Reverts e0ccb4b

PiperOrigin-RevId: 702291891
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants