Skip to content

[mlir][complex] Fastmath flag for complex angle #88658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2024

Conversation

Lewuathe
Copy link
Member

@llvmbot llvmbot added the mlir label Apr 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2024

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

Changes

See https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981


Full diff: https://github.com/llvm/llvm-project/pull/88658.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+10-7)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+32)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..ed266a45294410 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -945,6 +945,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     // The hyperbolic tangent for complex number can be calculated as follows.
     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -953,17 +954,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
-    Value cosB = rewriter.create<math::CosOp>(loc, imag);
-    Value sinB = rewriter.create<math::SinOp>(loc, imag);
-    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+    Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
+    Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
+    Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
+    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
     Value numerator =
         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
     Value one = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 1));
-    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
-    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
+    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
+                                                fmf);
     return success();
   }
 };
@@ -1141,13 +1143,14 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto type = op.getType();
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
 
-    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
+    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
 
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..53b1876f033121 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2017,3 +2017,35 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
+  return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_angle_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_angle_with_fmf(%arg: complex<f32>) -> f32 {
+  %angle = complex.angle %arg fastmath<nnan,contract> : complex<f32>
+  return %angle : f32
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: return %[[RESULT]] : f32
\ No newline at end of file

@Lewuathe Lewuathe changed the title Fastmathflag complex angle [mlir][complex] Fastmath flag for complex angle Apr 15, 2024
@Lewuathe Lewuathe force-pushed the fastmathflag-complex-angle branch from e202a21 to 05053db Compare April 15, 2024 02:15
@Lewuathe Lewuathe force-pushed the fastmathflag-complex-angle branch from 4151528 to b398bf9 Compare April 16, 2024 00:47
@Lewuathe Lewuathe merged commit 8c9d814 into llvm:main Apr 17, 2024
@Lewuathe Lewuathe deleted the fastmathflag-complex-angle branch April 17, 2024 00:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants