Skip to content

[mlir][complex] Support Fastmath flag in conversion of complex.sqrt to standard #85019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,21 +845,22 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
auto type = cast<ComplexType>(op.getType());
Type elementType = type.getElementType();
Value arg = adaptor.getComplex();
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));

Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());

Value absLhs = b.create<math::AbsFOp>(real);
Value absArg = b.create<complex::AbsOp>(elementType, arg);
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
Value absLhs = b.create<math::AbsFOp>(real, fmf);
Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);

Value half = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 0.5));
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);

Value realIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
Expand All @@ -869,7 +870,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
Value resultReal = sqrtAddAbs;

Value imagDivTwoResultReal = b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultReal, resultReal));
imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);

Value negativeResultReal = b.create<arith::NegFOp>(resultReal);

Expand All @@ -882,7 +883,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
resultReal = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultImag, resultImag)),
imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
resultReal);

Value realIsZero =
Expand Down
146 changes: 126 additions & 20 deletions mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -708,11 +708,60 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
// -----

// CHECK-LABEL: func @complex_sqrt
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
return %sqrt : complex<f32>
}

// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] : f32
// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] : f32
// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] : f32
// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] : f32
// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] : f32
// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] : f32
// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] : f32
// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] : f32
// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] : f32
// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] : f32
// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] : f32
// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] : f32
// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] : f32
// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] : f32
// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] : f32
// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] : f32
// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] : f32
// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
// CHECK: return %[[VAR41]] : complex<f32>

// -----

// CHECK-LABEL: func @complex_conj
Expand Down Expand Up @@ -1254,42 +1303,42 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR187:.*]] = complex.re %[[VAR186]] : complex<f32>
// CHECK: %[[VAR188:.*]] = complex.im %[[VAR186]] : complex<f32>
// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] : f32
// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] fastmath<nnan,contract> : f32
// CHECK: %[[CST_7:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST_8:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR190:.*]] = complex.re %[[VAR186]] : complex<f32>
// CHECK: %[[VAR191:.*]] = complex.im %[[VAR186]] : complex<f32>
// CHECK: %[[VAR192:.*]] = arith.cmpf oeq, %[[VAR190]], %[[CST_7]] : f32
// CHECK: %[[VAR193:.*]] = arith.cmpf oeq, %[[VAR191]], %[[CST_7]] : f32
// CHECK: %[[VAR194:.*]] = arith.divf %[[VAR191]], %[[VAR190]] : f32
// CHECK: %[[VAR195:.*]] = arith.mulf %[[VAR194]], %[[VAR194]] : f32
// CHECK: %[[VAR196:.*]] = arith.addf %[[VAR195]], %[[CST_8]] : f32
// CHECK: %[[VAR197:.*]] = math.sqrt %[[VAR196]] : f32
// CHECK: %[[VAR198:.*]] = math.absf %[[VAR190]] : f32
// CHECK: %[[VAR199:.*]] = arith.mulf %[[VAR197]], %[[VAR198]] : f32
// CHECK: %[[VAR200:.*]] = arith.divf %[[VAR190]], %[[VAR191]] : f32
// CHECK: %[[VAR201:.*]] = arith.mulf %[[VAR200]], %[[VAR200]] : f32
// CHECK: %[[VAR202:.*]] = arith.addf %[[VAR201]], %[[CST_8]] : f32
// CHECK: %[[VAR203:.*]] = math.sqrt %[[VAR202]] : f32
// CHECK: %[[VAR204:.*]] = math.absf %[[VAR191]] : f32
// CHECK: %[[VAR205:.*]] = arith.mulf %[[VAR203]], %[[VAR204]] : f32
// CHECK: %[[VAR194:.*]] = arith.divf %[[VAR191]], %[[VAR190]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR195:.*]] = arith.mulf %[[VAR194]], %[[VAR194]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR196:.*]] = arith.addf %[[VAR195]], %[[CST_8]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR197:.*]] = math.sqrt %[[VAR196]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR198:.*]] = math.absf %[[VAR190]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR199:.*]] = arith.mulf %[[VAR197]], %[[VAR198]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR200:.*]] = arith.divf %[[VAR190]], %[[VAR191]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR201:.*]] = arith.mulf %[[VAR200]], %[[VAR200]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR202:.*]] = arith.addf %[[VAR201]], %[[CST_8]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR203:.*]] = math.sqrt %[[VAR202]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR204:.*]] = math.absf %[[VAR191]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR205:.*]] = arith.mulf %[[VAR203]], %[[VAR204]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR206:.*]] = arith.cmpf ogt, %[[VAR190]], %[[VAR191]] : f32
// CHECK: %[[VAR207:.*]] = arith.select %[[VAR206]], %[[VAR199]], %[[VAR205]] : f32
// CHECK: %[[VAR208:.*]] = arith.select %[[VAR193]], %[[VAR198]], %[[VAR207]] : f32
// CHECK: %[[VAR209:.*]] = arith.select %[[VAR192]], %[[VAR204]], %[[VAR208]] : f32
// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[VAR209]] : f32
// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[VAR209]] fastmath<nnan,contract> : f32
// CHECK: %[[CST_9:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] : f32
// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] : f32
// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR213:.*]] = arith.cmpf olt, %[[VAR187]], %[[CST_6]] : f32
// CHECK: %[[VAR214:.*]] = arith.cmpf olt, %[[VAR188]], %[[CST_6]] : f32
// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] : f32
// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] : f32
// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR217:.*]] = arith.negf %[[VAR212]] : f32
// CHECK: %[[VAR218:.*]] = arith.select %[[VAR214]], %[[VAR217]], %[[VAR212]] : f32
// CHECK: %[[VAR219:.*]] = arith.select %[[VAR213]], %[[VAR218]], %[[VAR216]] : f32
// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] : f32
// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] : f32
// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR222:.*]] = arith.select %[[VAR213]], %[[VAR221]], %[[VAR212]] : f32
// CHECK: %[[VAR223:.*]] = arith.cmpf oeq, %[[VAR187]], %[[CST_6]] : f32
// CHECK: %[[VAR224:.*]] = arith.cmpf oeq, %[[VAR188]], %[[CST_6]] : f32
Expand Down Expand Up @@ -1728,3 +1777,60 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>

// -----

// CHECK-LABEL: func @complex_sqrt_with_fmf
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg fastmath<nnan,contract> : complex<f32>
return %sqrt : complex<f32>
}

// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] fastmath<nnan,contract> : f32
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] fastmath<nnan,contract> : f32
// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
// CHECK: return %[[VAR41]] : complex<f32>