Skip to content

[mlir][complex] Support Fastmath flag in conversion of complex.div to standard #82729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 47 additions & 36 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getLhs().getType());
auto elementType = cast<FloatType>(type.getElementType());
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

Value lhsReal =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
Expand Down Expand Up @@ -290,45 +291,51 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
//
// See https://dl.acm.org/citation.cfm?id=368661 for more details.
Value rhsRealImagRatio =
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
loc, rhsImag,
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
fmf);
Value realNumerator1 = rewriter.create<arith::AddFOp>(
loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
lhsImag);
Value resultReal1 =
rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
loc,
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
lhsImag, fmf);
Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
rhsRealImagDenom, fmf);
Value imagNumerator1 = rewriter.create<arith::SubFOp>(
loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
lhsReal);
Value resultImag1 =
rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
loc,
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
lhsReal, fmf);
Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
rhsRealImagDenom, fmf);

Value rhsImagRealRatio =
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
loc, rhsReal,
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
fmf);
Value realNumerator2 = rewriter.create<arith::AddFOp>(
loc, lhsReal,
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
Value resultReal2 =
rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
fmf);
Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
rhsImagRealDenom, fmf);
Value imagNumerator2 = rewriter.create<arith::SubFOp>(
loc, lhsImag,
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
Value resultImag2 =
rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
fmf);
Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
rhsImagRealDenom, fmf);

// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
Value zero = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getZeroAttr(elementType));
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal);
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag);
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
Expand All @@ -347,9 +354,9 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
Value infWithSignOfRhsReal =
rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
Value infinityResultReal =
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
Value infinityResultImag =
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);

// Case 2. Infinite numerator, finite denominator.
Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
Expand All @@ -358,10 +365,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
Value rhsFinite =
rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal);
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag);
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
Value lhsInfinite =
Expand All @@ -377,21 +384,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
lhsImag);
Value lhsRealIsInfWithSignTimesRhsReal =
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
Value lhsImagIsInfWithSignTimesRhsImag =
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
Value resultReal3 = rewriter.create<arith::MulFOp>(
loc, inf,
rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
lhsImagIsInfWithSignTimesRhsImag));
lhsImagIsInfWithSignTimesRhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesRhsImag =
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
Value lhsImagIsInfWithSignTimesRhsReal =
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
Value resultImag3 = rewriter.create<arith::MulFOp>(
loc, inf,
rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
lhsRealIsInfWithSignTimesRhsImag));
lhsRealIsInfWithSignTimesRhsImag, fmf),
fmf);

// Case 3: Finite numerator, infinite denominator.
Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
Expand All @@ -415,21 +424,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
rhsImag);
Value rhsRealIsInfWithSignTimesLhsReal =
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsImag =
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
Value resultReal4 = rewriter.create<arith::MulFOp>(
loc, zero,
rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
rhsImagIsInfWithSignTimesLhsImag));
rhsImagIsInfWithSignTimesLhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimesLhsImag =
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsReal =
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
Value resultImag4 = rewriter.create<arith::MulFOp>(
loc, zero,
rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
rhsImagIsInfWithSignTimesLhsReal));
rhsImagIsInfWithSignTimesLhsReal, fmf),
fmf);

Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
Expand Down
Loading