Skip to content

Commit 8891fd5

Browse files
authored
[mlir][complex] Fastmath flag support for complex.tanh (#88571)
1 parent ef164ce commit 8891fd5

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
973973
auto loc = op.getLoc();
974974
auto type = cast<ComplexType>(adaptor.getComplex().getType());
975975
auto elementType = cast<FloatType>(type.getElementType());
976+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
976977

977978
// The hyperbolic tangent for complex number can be calculated as follows.
978979
// tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -981,17 +982,18 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
981982
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
982983
Value imag =
983984
rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
984-
Value tanhA = rewriter.create<math::TanhOp>(loc, real);
985-
Value cosB = rewriter.create<math::CosOp>(loc, imag);
986-
Value sinB = rewriter.create<math::SinOp>(loc, imag);
987-
Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
985+
Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
986+
Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
987+
Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
988+
Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
988989
Value numerator =
989990
rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
990991
Value one = rewriter.create<arith::ConstantOp>(
991992
loc, elementType, rewriter.getFloatAttr(elementType, 1));
992-
Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
993+
Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
993994
Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
994-
rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
995+
rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
996+
fmf);
995997
return success();
996998
}
997999
};

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,3 +2085,22 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
20852085
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
20862086
// CHECK: return %[[RESULT]] : complex<f32>
20872087

2088+
2089+
// -----
2090+
2091+
// CHECK-LABEL: func @complex_tanh_with_fmf
2092+
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
2093+
func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
2094+
%tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
2095+
return %tanh : complex<f32>
2096+
}
2097+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
2098+
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
2099+
// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
2100+
// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
2101+
// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
2102+
// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
2103+
// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
2104+
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
2105+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
2106+
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>

0 commit comments

Comments
 (0)