Skip to content

Commit c9ae553

Browse files
committed
[mlir][complex] Fastmath flag support for complex.tanh
1 parent b45c9c3 commit c9ae553

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
945945
auto loc = op.getLoc();
946946
auto type = cast<ComplexType>(adaptor.getComplex().getType());
947947
auto elementType = cast<FloatType>(type.getElementType());
948+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
948949

949950
// The hyperbolic tangent for complex number can be calculated as follows.
950951
// tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -953,17 +954,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
953954
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
954955
Value imag =
955956
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
956-
Value tanhA = rewriter.create<math::TanhOp>(loc, real);
957-
Value cosB = rewriter.create<math::CosOp>(loc, imag);
958-
Value sinB = rewriter.create<math::SinOp>(loc, imag);
959-
Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
957+
Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
958+
Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
959+
Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
960+
Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
960961
Value numerator =
961962
rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
962963
Value one = rewriter.create<arith::ConstantOp>(
963964
loc, elementType, rewriter.getFloatAttr(elementType, 1));
964-
Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
965+
Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
965966
Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
966-
rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
967+
rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
968+
fmf);
967969
return success();
968970
}
969971
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,3 +2017,22 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
20172017
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
20182018
// CHECK: return %[[RESULT]] : complex<f32>
20192019

2020+
2021+
// -----
2022+
2023+
// CHECK-LABEL: func @complex_tanh_with_fmf
2024+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
2025+
func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
2026+
%tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
2027+
return %tanh : complex<f32>
2028+
}
2029+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
2030+
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
2031+
// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
2032+
// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
2033+
// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
2034+
// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
2035+
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
2036+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
2037+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
2038+
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>

0 commit comments

Comments
 (0)