@@ -499,13 +499,16 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
499
499
ConversionPatternRewriter &rewriter) const override {
500
500
auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
501
501
auto elementType = cast<FloatType>(type.getElementType ());
502
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
502
503
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
503
504
504
- Value abs = b.create <complex::AbsOp>(elementType, adaptor.getComplex ());
505
- Value resultReal = b.create <math::LogOp>(elementType, abs);
505
+ Value abs = b.create <complex::AbsOp>(elementType, adaptor.getComplex (),
506
+ fmf.getValue ());
507
+ Value resultReal = b.create <math::LogOp>(elementType, abs, fmf.getValue ());
506
508
Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
507
509
Value imag = b.create <complex::ImOp>(elementType, adaptor.getComplex ());
508
- Value resultImag = b.create <math::Atan2Op>(elementType, imag, real);
510
+ Value resultImag =
511
+ b.create <math::Atan2Op>(elementType, imag, real, fmf.getValue ());
509
512
rewriter.replaceOpWithNewOp <complex::CreateOp>(op, type, resultReal,
510
513
resultImag);
511
514
return success ();
@@ -520,6 +523,7 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
520
523
ConversionPatternRewriter &rewriter) const override {
521
524
auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
522
525
auto elementType = cast<FloatType>(type.getElementType ());
526
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
523
527
mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
524
528
525
529
Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
@@ -535,15 +539,21 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
535
539
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
536
540
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
537
541
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
538
- Value sumSq = b.create <arith::MulFOp>(real, real);
539
- sumSq = b.create <arith::AddFOp>(sumSq, b.create <arith::MulFOp>(real, two));
540
- sumSq = b.create <arith::AddFOp>(sumSq, b.create <arith::MulFOp>(imag, imag));
541
- Value logSumSq = b.create <math::Log1pOp>(elementType, sumSq);
542
- Value resultReal = b.create <arith::MulFOp>(logSumSq, half);
542
+ Value sumSq = b.create <arith::MulFOp>(real, real, fmf.getValue ());
543
+ sumSq = b.create <arith::AddFOp>(
544
+ sumSq, b.create <arith::MulFOp>(real, two, fmf.getValue ()),
545
+ fmf.getValue ());
546
+ sumSq = b.create <arith::AddFOp>(
547
+ sumSq, b.create <arith::MulFOp>(imag, imag, fmf.getValue ()),
548
+ fmf.getValue ());
549
+ Value logSumSq =
550
+ b.create <math::Log1pOp>(elementType, sumSq, fmf.getValue ());
551
+ Value resultReal = b.create <arith::MulFOp>(logSumSq, half, fmf.getValue ());
552
+
553
+ Value realPlusOne = b.create <arith::AddFOp>(real, one, fmf.getValue ());
543
554
544
- Value realPlusOne = b.create <arith::AddFOp>(real, one);
545
-
546
- Value resultImag = b.create <math::Atan2Op>(elementType, imag, realPlusOne);
555
+ Value resultImag =
556
+ b.create <math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue ());
547
557
rewriter.replaceOpWithNewOp <complex::CreateOp>(op, type, resultReal,
548
558
resultImag);
549
559
return success ();
0 commit comments