Skip to content

Commit 81bc4ac

Browse files
committed
[mlir][complex] Support Fastmath flag in conversion of complex.div to standard
1 parent 330af6e commit 81bc4ac

File tree

2 files changed

+167
-42
lines changed

2 files changed

+167
-42
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
195195
auto loc = op.getLoc();
196196
auto type = cast<ComplexType>(adaptor.getComplex().getType());
197197
auto elementType = cast<FloatType>(type.getElementType());
198+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
198199

199200
Value real =
200201
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -206,11 +207,13 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
206207
// implementation in the subclass to combine them.
207208
Value half = rewriter.create<arith::ConstantOp>(
208209
loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
209-
Value exp = rewriter.create<math::ExpOp>(loc, imag);
210-
Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
211-
Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
212-
Value sin = rewriter.create<math::SinOp>(loc, real);
213-
Value cos = rewriter.create<math::CosOp>(loc, real);
210+
Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf.getValue());
211+
Value scaledExp =
212+
rewriter.create<arith::MulFOp>(loc, half, exp, fmf.getValue());
213+
Value reciprocalExp =
214+
rewriter.create<arith::DivFOp>(loc, half, exp, fmf.getValue());
215+
Value sin = rewriter.create<math::SinOp>(loc, real, fmf.getValue());
216+
Value cos = rewriter.create<math::CosOp>(loc, real, fmf.getValue());
214217

215218
auto resultPair =
216219
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
@@ -257,6 +260,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
257260
auto loc = op.getLoc();
258261
auto type = cast<ComplexType>(adaptor.getLhs().getType());
259262
auto elementType = cast<FloatType>(type.getElementType());
263+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
260264

261265
Value lhsReal =
262266
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
@@ -290,45 +294,51 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
290294
//
291295
// See https://dl.acm.org/citation.cfm?id=368661 for more details.
292296
Value rhsRealImagRatio =
293-
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
297+
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
294298
Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
295299
loc, rhsImag,
296-
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
300+
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
301+
fmf);
297302
Value realNumerator1 = rewriter.create<arith::AddFOp>(
298-
loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
299-
lhsImag);
300-
Value resultReal1 =
301-
rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
303+
loc,
304+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
305+
lhsImag, fmf);
306+
Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
307+
rhsRealImagDenom, fmf);
302308
Value imagNumerator1 = rewriter.create<arith::SubFOp>(
303-
loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
304-
lhsReal);
305-
Value resultImag1 =
306-
rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
309+
loc,
310+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
311+
lhsReal, fmf);
312+
Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
313+
rhsRealImagDenom, fmf);
307314

308315
Value rhsImagRealRatio =
309-
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
316+
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
310317
Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
311318
loc, rhsReal,
312-
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
319+
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
320+
fmf);
313321
Value realNumerator2 = rewriter.create<arith::AddFOp>(
314322
loc, lhsReal,
315-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
316-
Value resultReal2 =
317-
rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
323+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
324+
fmf);
325+
Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
326+
rhsImagRealDenom, fmf);
318327
Value imagNumerator2 = rewriter.create<arith::SubFOp>(
319328
loc, lhsImag,
320-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
321-
Value resultImag2 =
322-
rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
329+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
330+
fmf);
331+
Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
332+
rhsImagRealDenom, fmf);
323333

324334
// Consider corner cases.
325335
// Case 1. Zero denominator, numerator contains at most one NaN value.
326336
Value zero = rewriter.create<arith::ConstantOp>(
327337
loc, elementType, rewriter.getZeroAttr(elementType));
328-
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal);
338+
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
329339
Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
330340
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
331-
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag);
341+
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
332342
Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
333343
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
334344
Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
@@ -347,9 +357,9 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
347357
Value infWithSignOfRhsReal =
348358
rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
349359
Value infinityResultReal =
350-
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
360+
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
351361
Value infinityResultImag =
352-
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
362+
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
353363

354364
// Case 2. Infinite numerator, finite denominator.
355365
Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -358,10 +368,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
358368
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
359369
Value rhsFinite =
360370
rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
361-
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal);
371+
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
362372
Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
363373
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
364-
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag);
374+
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
365375
Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
366376
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
367377
Value lhsInfinite =
@@ -377,21 +387,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
377387
loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
378388
lhsImag);
379389
Value lhsRealIsInfWithSignTimesRhsReal =
380-
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
390+
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
381391
Value lhsImagIsInfWithSignTimesRhsImag =
382-
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
392+
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
383393
Value resultReal3 = rewriter.create<arith::MulFOp>(
384394
loc, inf,
385395
rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
386-
lhsImagIsInfWithSignTimesRhsImag));
396+
lhsImagIsInfWithSignTimesRhsImag, fmf),
397+
fmf);
387398
Value lhsRealIsInfWithSignTimesRhsImag =
388-
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
399+
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
389400
Value lhsImagIsInfWithSignTimesRhsReal =
390-
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
401+
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
391402
Value resultImag3 = rewriter.create<arith::MulFOp>(
392403
loc, inf,
393404
rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
394-
lhsRealIsInfWithSignTimesRhsImag));
405+
lhsRealIsInfWithSignTimesRhsImag, fmf),
406+
fmf);
395407

396408
// Case 3: Finite numerator, infinite denominator.
397409
Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -415,21 +427,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
415427
loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
416428
rhsImag);
417429
Value rhsRealIsInfWithSignTimesLhsReal =
418-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
430+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
419431
Value rhsImagIsInfWithSignTimesLhsImag =
420-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
432+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
421433
Value resultReal4 = rewriter.create<arith::MulFOp>(
422434
loc, zero,
423435
rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
424-
rhsImagIsInfWithSignTimesLhsImag));
436+
rhsImagIsInfWithSignTimesLhsImag, fmf),
437+
fmf);
425438
Value rhsRealIsInfWithSignTimesLhsImag =
426-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
439+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
427440
Value rhsImagIsInfWithSignTimesLhsReal =
428-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
441+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
429442
Value resultImag4 = rewriter.create<arith::MulFOp>(
430443
loc, zero,
431444
rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
432-
rhsImagIsInfWithSignTimesLhsReal));
445+
rhsImagIsInfWithSignTimesLhsReal, fmf),
446+
fmf);
433447

434448
Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
435449
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);

0 commit comments

Comments
 (0)