Skip to content

Commit 0a81f0b

Browse files
committed
[mlir][complex] Support Fastmath flag in conversion of complex.div to standard
1 parent 0e8d187 commit 0a81f0b

File tree

2 files changed

+191
-52
lines changed

2 files changed

+191
-52
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 79 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
195195
auto loc = op.getLoc();
196196
auto type = cast<ComplexType>(adaptor.getComplex().getType());
197197
auto elementType = cast<FloatType>(type.getElementType());
198+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
198199

199200
Value real =
200201
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -206,11 +207,13 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
206207
// implementation in the subclass to combine them.
207208
Value half = rewriter.create<arith::ConstantOp>(
208209
loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
209-
Value exp = rewriter.create<math::ExpOp>(loc, imag);
210-
Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
211-
Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
212-
Value sin = rewriter.create<math::SinOp>(loc, real);
213-
Value cos = rewriter.create<math::CosOp>(loc, real);
210+
Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf.getValue());
211+
Value scaledExp =
212+
rewriter.create<arith::MulFOp>(loc, half, exp, fmf.getValue());
213+
Value reciprocalExp =
214+
rewriter.create<arith::DivFOp>(loc, half, exp, fmf.getValue());
215+
Value sin = rewriter.create<math::SinOp>(loc, real, fmf.getValue());
216+
Value cos = rewriter.create<math::CosOp>(loc, real, fmf.getValue());
214217

215218
auto resultPair =
216219
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
@@ -257,6 +260,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
257260
auto loc = op.getLoc();
258261
auto type = cast<ComplexType>(adaptor.getLhs().getType());
259262
auto elementType = cast<FloatType>(type.getElementType());
263+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
260264

261265
Value lhsReal =
262266
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
@@ -290,45 +294,59 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
290294
//
291295
// See https://dl.acm.org/citation.cfm?id=368661 for more details.
292296
Value rhsRealImagRatio =
293-
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
297+
rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf.getValue());
294298
Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
295299
loc, rhsImag,
296-
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
300+
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal,
301+
fmf.getValue()),
302+
fmf.getValue());
297303
Value realNumerator1 = rewriter.create<arith::AddFOp>(
298-
loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
299-
lhsImag);
300-
Value resultReal1 =
301-
rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
304+
loc,
305+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio,
306+
fmf.getValue()),
307+
lhsImag, fmf.getValue());
308+
Value resultReal1 = rewriter.create<arith::DivFOp>(
309+
loc, realNumerator1, rhsRealImagDenom, fmf.getValue());
302310
Value imagNumerator1 = rewriter.create<arith::SubFOp>(
303-
loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
304-
lhsReal);
305-
Value resultImag1 =
306-
rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
311+
loc,
312+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio,
313+
fmf.getValue()),
314+
lhsReal, fmf.getValue());
315+
Value resultImag1 = rewriter.create<arith::DivFOp>(
316+
loc, imagNumerator1, rhsRealImagDenom, fmf.getValue());
307317

308318
Value rhsImagRealRatio =
309-
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
319+
rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf.getValue());
310320
Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
311321
loc, rhsReal,
312-
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
322+
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag,
323+
fmf.getValue()),
324+
fmf.getValue());
313325
Value realNumerator2 = rewriter.create<arith::AddFOp>(
314326
loc, lhsReal,
315-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
316-
Value resultReal2 =
317-
rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
327+
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio,
328+
fmf.getValue()),
329+
fmf.getValue());
330+
Value resultReal2 = rewriter.create<arith::DivFOp>(
331+
loc, realNumerator2, rhsImagRealDenom, fmf.getValue());
318332
Value imagNumerator2 = rewriter.create<arith::SubFOp>(
319333
loc, lhsImag,
320-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
321-
Value resultImag2 =
322-
rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
334+
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio,
335+
fmf.getValue()),
336+
fmf.getValue());
337+
Value resultImag2 = rewriter.create<arith::DivFOp>(
338+
loc, imagNumerator2, rhsImagRealDenom, fmf.getValue());
323339

324340
// Consider corner cases.
325341
// Case 1. Zero denominator, numerator contains at most one NaN value.
326342
Value zero = rewriter.create<arith::ConstantOp>(
327343
loc, elementType, rewriter.getZeroAttr(elementType));
328-
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal);
344+
Value rhsRealAbs =
345+
rewriter.create<math::AbsFOp>(loc, rhsReal, fmf.getValue());
329346
Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
330347
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
331-
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag);
348+
Value rhsImagAbs =
349+
rewriter.create<math::AbsFOp>(loc, rhsImag, fmf.getValue());
332350
Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
333351
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
334352
Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
@@ -346,10 +364,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
346364
elementType, APFloat::getInf(elementType.getFloatSemantics())));
347365
Value infWithSignOfRhsReal =
348366
rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
349-
Value infinityResultReal =
350-
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
351-
Value infinityResultImag =
352-
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
367+
Value infinityResultReal = rewriter.create<arith::MulFOp>(
368+
loc, infWithSignOfRhsReal, lhsReal, fmf.getValue());
369+
Value infinityResultImag = rewriter.create<arith::MulFOp>(
370+
loc, infWithSignOfRhsReal, lhsImag, fmf.getValue());
353371

354372
// Case 2. Infinite numerator, finite denominator.
355373
Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -358,10 +376,12 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
358376
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
359377
Value rhsFinite =
360378
rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
361-
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal);
379+
Value lhsRealAbs =
380+
rewriter.create<math::AbsFOp>(loc, lhsReal, fmf.getValue());
362381
Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
363382
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
364-
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag);
383+
Value lhsImagAbs =
384+
rewriter.create<math::AbsFOp>(loc, lhsImag, fmf.getValue());
365385
Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
366386
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
367387
Value lhsInfinite =
@@ -376,22 +396,26 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
376396
Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
377397
loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
378398
lhsImag);
379-
Value lhsRealIsInfWithSignTimesRhsReal =
380-
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
381-
Value lhsImagIsInfWithSignTimesRhsImag =
382-
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
399+
Value lhsRealIsInfWithSignTimesRhsReal = rewriter.create<arith::MulFOp>(
400+
loc, lhsRealIsInfWithSign, rhsReal, fmf.getValue());
401+
Value lhsImagIsInfWithSignTimesRhsImag = rewriter.create<arith::MulFOp>(
402+
loc, lhsImagIsInfWithSign, rhsImag, fmf.getValue());
383403
Value resultReal3 = rewriter.create<arith::MulFOp>(
384404
loc, inf,
385405
rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
386-
lhsImagIsInfWithSignTimesRhsImag));
387-
Value lhsRealIsInfWithSignTimesRhsImag =
388-
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
389-
Value lhsImagIsInfWithSignTimesRhsReal =
390-
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
406+
lhsImagIsInfWithSignTimesRhsImag,
407+
fmf.getValue()),
408+
fmf.getValue());
409+
Value lhsRealIsInfWithSignTimesRhsImag = rewriter.create<arith::MulFOp>(
410+
loc, lhsRealIsInfWithSign, rhsImag, fmf.getValue());
411+
Value lhsImagIsInfWithSignTimesRhsReal = rewriter.create<arith::MulFOp>(
412+
loc, lhsImagIsInfWithSign, rhsReal, fmf.getValue());
391413
Value resultImag3 = rewriter.create<arith::MulFOp>(
392414
loc, inf,
393415
rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
394-
lhsRealIsInfWithSignTimesRhsImag));
416+
lhsRealIsInfWithSignTimesRhsImag,
417+
fmf.getValue()),
418+
fmf.getValue());
395419

396420
// Case 3: Finite numerator, infinite denominator.
397421
Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -414,22 +438,26 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
414438
Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
415439
loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
416440
rhsImag);
417-
Value rhsRealIsInfWithSignTimesLhsReal =
418-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
419-
Value rhsImagIsInfWithSignTimesLhsImag =
420-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
441+
Value rhsRealIsInfWithSignTimesLhsReal = rewriter.create<arith::MulFOp>(
442+
loc, lhsReal, rhsRealIsInfWithSign, fmf.getValue());
443+
Value rhsImagIsInfWithSignTimesLhsImag = rewriter.create<arith::MulFOp>(
444+
loc, lhsImag, rhsImagIsInfWithSign, fmf.getValue());
421445
Value resultReal4 = rewriter.create<arith::MulFOp>(
422446
loc, zero,
423447
rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
424-
rhsImagIsInfWithSignTimesLhsImag));
425-
Value rhsRealIsInfWithSignTimesLhsImag =
426-
rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
427-
Value rhsImagIsInfWithSignTimesLhsReal =
428-
rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
448+
rhsImagIsInfWithSignTimesLhsImag,
449+
fmf.getValue()),
450+
fmf.getValue());
451+
Value rhsRealIsInfWithSignTimesLhsImag = rewriter.create<arith::MulFOp>(
452+
loc, lhsImag, rhsRealIsInfWithSign, fmf.getValue());
453+
Value rhsImagIsInfWithSignTimesLhsReal = rewriter.create<arith::MulFOp>(
454+
loc, lhsReal, rhsImagIsInfWithSign, fmf.getValue());
429455
Value resultImag4 = rewriter.create<arith::MulFOp>(
430456
loc, zero,
431457
rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
432-
rhsImagIsInfWithSignTimesLhsReal));
458+
rhsImagIsInfWithSignTimesLhsReal,
459+
fmf.getValue()),
460+
fmf.getValue());
433461

434462
Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
435463
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);

0 commit comments

Comments
 (0)