Skip to content

Commit 288d317

Browse files
authored
[mlir][complex] Support Fastmath flag in conversion of complex.div to standard (#82729)
Propagate the fastmath flags of `complex.div` to the `arith` and `math` operations generated when it is converted to the standard dialects. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
1 parent e421c12 commit 288d317

File tree

2 files changed

+159
-37
lines changed

2 files changed

+159
-37
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
257257
auto loc = op.getLoc();
258258
auto type = cast<ComplexType>(adaptor.getLhs().getType());
259259
auto elementType = cast<FloatType>(type.getElementType());
260+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
260261

261262
Value lhsReal =
262263
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
@@ -290,45 +291,51 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
290291
//
291292
// See https://dl.acm.org/citation.cfm?id=368661 for more details.
292293
Value rhsRealImagRatio =
293-
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
294+
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
294295
Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
295296
loc, rhsImag,
296-
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
297+
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
298+
fmf);
297299
Value realNumerator1 = rewriter.create<arith::AddFOp>(
298-
loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
299-
lhsImag);
300-
Value resultReal1 =
301-
rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
300+
loc,
301+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
302+
lhsImag, fmf);
303+
Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
304+
rhsRealImagDenom, fmf);
302305
Value imagNumerator1 = rewriter.create<arith::SubFOp>(
303-
loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
304-
lhsReal);
305-
Value resultImag1 =
306-
rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
306+
loc,
307+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
308+
lhsReal, fmf);
309+
Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
310+
rhsRealImagDenom, fmf);
307311

308312
Value rhsImagRealRatio =
309-
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
313+
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
310314
Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
311315
loc, rhsReal,
312-
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
316+
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
317+
fmf);
313318
Value realNumerator2 = rewriter.create<arith::AddFOp>(
314319
loc, lhsReal,
315-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
316-
Value resultReal2 =
317-
rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
320+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
321+
fmf);
322+
Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
323+
rhsImagRealDenom, fmf);
318324
Value imagNumerator2 = rewriter.create<arith::SubFOp>(
319325
loc, lhsImag,
320-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
321-
Value resultImag2 =
322-
rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
326+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
327+
fmf);
328+
Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
329+
rhsImagRealDenom, fmf);
323330

324331
// Consider corner cases.
325332
// Case 1. Zero denominator, numerator contains at most one NaN value.
326333
Value zero = rewriter.create<arith::ConstantOp>(
327334
loc, elementType, rewriter.getZeroAttr(elementType));
328-
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal);
335+
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
329336
Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
330337
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
331-
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag);
338+
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
332339
Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
333340
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
334341
Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
@@ -347,9 +354,9 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
347354
Value infWithSignOfRhsReal =
348355
rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
349356
Value infinityResultReal =
350-
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
357+
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
351358
Value infinityResultImag =
352-
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
359+
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
353360

354361
// Case 2. Infinite numerator, finite denominator.
355362
Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -358,10 +365,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
358365
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
359366
Value rhsFinite =
360367
rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
361-
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal);
368+
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
362369
Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
363370
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
364-
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag);
371+
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
365372
Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
366373
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
367374
Value lhsInfinite =
@@ -377,21 +384,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
377384
loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
378385
lhsImag);
379386
Value lhsRealIsInfWithSignTimesRhsReal =
380-
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
387+
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
381388
Value lhsImagIsInfWithSignTimesRhsImag =
382-
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
389+
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
383390
Value resultReal3 = rewriter.create<arith::MulFOp>(
384391
loc, inf,
385392
rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
386-
lhsImagIsInfWithSignTimesRhsImag));
393+
lhsImagIsInfWithSignTimesRhsImag, fmf),
394+
fmf);
387395
Value lhsRealIsInfWithSignTimesRhsImag =
388-
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
396+
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
389397
Value lhsImagIsInfWithSignTimesRhsReal =
390-
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
398+
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
391399
Value resultImag3 = rewriter.create<arith::MulFOp>(
392400
loc, inf,
393401
rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
394-
lhsRealIsInfWithSignTimesRhsImag));
402+
lhsRealIsInfWithSignTimesRhsImag, fmf),
403+
fmf);
395404

396405
// Case 3: Finite numerator, infinite denominator.
397406
Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -415,21 +424,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
415424
loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
416425
rhsImag);
417426
Value rhsRealIsInfWithSignTimesLhsReal =
418-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
427+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
419428
Value rhsImagIsInfWithSignTimesLhsImag =
420-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
429+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
421430
Value resultReal4 = rewriter.create<arith::MulFOp>(
422431
loc, zero,
423432
rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
424-
rhsImagIsInfWithSignTimesLhsImag));
433+
rhsImagIsInfWithSignTimesLhsImag, fmf),
434+
fmf);
425435
Value rhsRealIsInfWithSignTimesLhsImag =
426-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
436+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
427437
Value rhsImagIsInfWithSignTimesLhsReal =
428-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
438+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
429439
Value resultImag4 = rewriter.create<arith::MulFOp>(
430440
loc, zero,
431441
rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
432-
rhsImagIsInfWithSignTimesLhsReal));
442+
rhsImagIsInfWithSignTimesLhsReal, fmf),
443+
fmf);
433444

434445
Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
435446
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);

0 commit comments

Comments
 (0)