Skip to content

Commit d230bf3

Browse files
authored
[mlir][complex] Support Fastmath flag in the conversion of exp,expm1 (llvm#67001)
See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent 5ba239f commit d230bf3

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -446,16 +446,19 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
446446
auto loc = op.getLoc();
447447
auto type = cast<ComplexType>(adaptor.getComplex().getType());
448448
auto elementType = cast<FloatType>(type.getElementType());
449+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
449450

450451
Value real =
451452
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
452453
Value imag =
453454
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
454-
Value expReal = rewriter.create<math::ExpOp>(loc, real);
455-
Value cosImag = rewriter.create<math::CosOp>(loc, imag);
456-
Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
457-
Value sinImag = rewriter.create<math::SinOp>(loc, imag);
458-
Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
455+
Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
456+
Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
457+
Value resultReal =
458+
rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
459+
Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
460+
Value resultImag =
461+
rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
459462

460463
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
461464
resultImag);
@@ -471,14 +474,15 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
471474
ConversionPatternRewriter &rewriter) const override {
472475
auto type = cast<ComplexType>(adaptor.getComplex().getType());
473476
auto elementType = cast<FloatType>(type.getElementType());
477+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
474478

475479
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
476-
Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
480+
Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
477481

478482
Value real = b.create<complex::ReOp>(elementType, exp);
479483
Value one = b.create<arith::ConstantOp>(elementType,
480484
b.getFloatAttr(elementType, 1));
481-
Value realMinusOne = b.create<arith::SubFOp>(real, one);
485+
Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
482486
Value imag = b.create<complex::ImOp>(elementType, exp);
483487

484488
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,3 +757,44 @@ func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
757757
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
758758
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
759759
// CHECK: return %[[RESULT]] : complex<f32>
760+
761+
// -----
762+
763+
// CHECK-LABEL: func @complex_exp_with_fmf
764+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
765+
func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
766+
%exp = complex.exp %arg fastmath<nnan,contract> : complex<f32>
767+
return %exp : complex<f32>
768+
}
769+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
770+
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
771+
// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
772+
// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
773+
// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
774+
// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
775+
// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
776+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
777+
// CHECK: return %[[RESULT]] : complex<f32>
778+
779+
// -----
780+
781+
// CHECK-LABEL: func.func @complex_expm1_with_fmf(
782+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
783+
func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
784+
%expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
785+
return %expm1 : complex<f32>
786+
}
787+
// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
788+
// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
789+
// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
790+
// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
791+
// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
792+
// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
793+
// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
794+
// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
795+
// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
796+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
797+
// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
798+
// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
799+
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
800+
// CHECK: return %[[RES]] : complex<f32>

0 commit comments

Comments
 (0)