Skip to content

Commit 8aaa2cb

Browse files
authored
[mlir][complex] Support Fastmath flag for complex log ops (#69798)
Progressive support of fastmath flag in the conversion of log type ops. See more detail https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent ac7c816 commit 8aaa2cb

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -499,13 +499,16 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
499499
ConversionPatternRewriter &rewriter) const override {
500500
auto type = cast<ComplexType>(adaptor.getComplex().getType());
501501
auto elementType = cast<FloatType>(type.getElementType());
502+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
502503
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
503504

504-
Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
505-
Value resultReal = b.create<math::LogOp>(elementType, abs);
505+
Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
506+
fmf.getValue());
507+
Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
506508
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
507509
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
508-
Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
510+
Value resultImag =
511+
b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
509512
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
510513
resultImag);
511514
return success();
@@ -520,6 +523,7 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
520523
ConversionPatternRewriter &rewriter) const override {
521524
auto type = cast<ComplexType>(adaptor.getComplex().getType());
522525
auto elementType = cast<FloatType>(type.getElementType());
526+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
523527
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
524528

525529
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
@@ -535,15 +539,21 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
535539
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
536540
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
537541
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
538-
Value sumSq = b.create<arith::MulFOp>(real, real);
539-
sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
540-
sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
541-
Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
542-
Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
542+
Value sumSq = b.create<arith::MulFOp>(real, real, fmf.getValue());
543+
sumSq = b.create<arith::AddFOp>(
544+
sumSq, b.create<arith::MulFOp>(real, two, fmf.getValue()),
545+
fmf.getValue());
546+
sumSq = b.create<arith::AddFOp>(
547+
sumSq, b.create<arith::MulFOp>(imag, imag, fmf.getValue()),
548+
fmf.getValue());
549+
Value logSumSq =
550+
b.create<math::Log1pOp>(elementType, sumSq, fmf.getValue());
551+
Value resultReal = b.create<arith::MulFOp>(logSumSq, half, fmf.getValue());
552+
553+
Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf.getValue());
543554

544-
Value realPlusOne = b.create<arith::AddFOp>(real, one);
545-
546-
Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
555+
Value resultImag =
556+
b.create<math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue());
547557
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
548558
resultImag);
549559
return success();

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,4 +797,51 @@ func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
797797
// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
798798
// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
799799
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
800-
// CHECK: return %[[RES]] : complex<f32>
800+
// CHECK: return %[[RES]] : complex<f32>
801+
802+
// -----
803+
804+
// CHECK-LABEL: func @complex_log_with_fmf
805+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
806+
func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
807+
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
808+
return %log : complex<f32>
809+
}
810+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
811+
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
812+
// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
813+
// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
814+
// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
815+
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
816+
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
817+
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
818+
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
819+
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] fastmath<nnan,contract> : f32
820+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
821+
// CHECK: return %[[RESULT]] : complex<f32>
822+
823+
// -----
824+
825+
// CHECK-LABEL: func @complex_log1p_with_fmf
826+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
827+
func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
828+
%log1p = complex.log1p %arg fastmath<nnan,contract> : complex<f32>
829+
return %log1p : complex<f32>
830+
}
831+
832+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
833+
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
834+
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
835+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
836+
// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
837+
// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
838+
// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath<nnan,contract> : f32
839+
// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
840+
// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
841+
// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath<nnan,contract> : f32
842+
// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath<nnan,contract> : f32
843+
// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath<nnan,contract> : f32
844+
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
845+
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
846+
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
847+
// CHECK: return %[[RESULT]] : complex<f32>

0 commit comments

Comments
 (0)