Skip to content

[mlir][complex] Fastmath flag support for complex.tanh #88571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2024

Conversation

Lewuathe
Copy link
Member

@llvmbot
Copy link
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

Changes

See https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981?u=lewuathe


Full diff: https://github.com/llvm/llvm-project/pull/88571.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+8-6)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+19)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..9dc146da7ee142 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -945,6 +945,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     // The hyperbolic tangent for complex number can be calculated as follows.
     // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -953,17 +954,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
         rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
-    Value cosB = rewriter.create<math::CosOp>(loc, imag);
-    Value sinB = rewriter.create<math::SinOp>(loc, imag);
-    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+    Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
+    Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
+    Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
+    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
     Value numerator =
         rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
     Value one = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 1));
-    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
     Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
-    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
+    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
+                                                fmf);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..5aec9260867f3f 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -2017,3 +2017,22 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
+  return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file

@Lewuathe Lewuathe requested review from tpopp and joker-eph April 12, 2024 19:59
@Lewuathe Lewuathe merged commit 8891fd5 into llvm:main Apr 14, 2024
@Lewuathe Lewuathe deleted the fastmathflag-complex-tanh branch April 14, 2024 10:52
bazuzi pushed a commit to bazuzi/llvm-project that referenced this pull request Apr 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants