@@ -195,6 +195,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
195
195
auto loc = op.getLoc ();
196
196
auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
197
197
auto elementType = cast<FloatType>(type.getElementType ());
198
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
198
199
199
200
Value real =
200
201
rewriter.create <complex::ReOp>(loc, elementType, adaptor.getComplex ());
@@ -206,11 +207,13 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
206
207
// implementation in the subclass to combine them.
207
208
Value half = rewriter.create <arith::ConstantOp>(
208
209
loc, elementType, rewriter.getFloatAttr (elementType, 0.5 ));
209
- Value exp = rewriter.create <math::ExpOp>(loc, imag);
210
- Value scaledExp = rewriter.create <arith::MulFOp>(loc, half, exp);
211
- Value reciprocalExp = rewriter.create <arith::DivFOp>(loc, half, exp);
212
- Value sin = rewriter.create <math::SinOp>(loc, real);
213
- Value cos = rewriter.create <math::CosOp>(loc, real);
210
+ Value exp = rewriter.create <math::ExpOp>(loc, imag, fmf.getValue ());
211
+ Value scaledExp =
212
+ rewriter.create <arith::MulFOp>(loc, half, exp, fmf.getValue ());
213
+ Value reciprocalExp =
214
+ rewriter.create <arith::DivFOp>(loc, half, exp, fmf.getValue ());
215
+ Value sin = rewriter.create <math::SinOp>(loc, real, fmf.getValue ());
216
+ Value cos = rewriter.create <math::CosOp>(loc, real, fmf.getValue ());
214
217
215
218
auto resultPair =
216
219
combine (loc, scaledExp, reciprocalExp, sin, cos, rewriter);
@@ -257,6 +260,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
257
260
auto loc = op.getLoc ();
258
261
auto type = cast<ComplexType>(adaptor.getLhs ().getType ());
259
262
auto elementType = cast<FloatType>(type.getElementType ());
263
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
260
264
261
265
Value lhsReal =
262
266
rewriter.create <complex::ReOp>(loc, elementType, adaptor.getLhs ());
@@ -290,45 +294,51 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
290
294
//
291
295
// See https://dl.acm.org/citation.cfm?id=368661 for more details.
292
296
Value rhsRealImagRatio =
293
- rewriter.create <arith::DivFOp>(loc, rhsReal, rhsImag);
297
+ rewriter.create <arith::DivFOp>(loc, rhsReal, rhsImag, fmf );
294
298
Value rhsRealImagDenom = rewriter.create <arith::AddFOp>(
295
299
loc, rhsImag,
296
- rewriter.create <arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
300
+ rewriter.create <arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
301
+ fmf);
297
302
Value realNumerator1 = rewriter.create <arith::AddFOp>(
298
- loc, rewriter.create <arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
299
- lhsImag);
300
- Value resultReal1 =
301
- rewriter.create <arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
303
+ loc,
304
+ rewriter.create <arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
305
+ lhsImag, fmf);
306
+ Value resultReal1 = rewriter.create <arith::DivFOp>(loc, realNumerator1,
307
+ rhsRealImagDenom, fmf);
302
308
Value imagNumerator1 = rewriter.create <arith::SubFOp>(
303
- loc, rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
304
- lhsReal);
305
- Value resultImag1 =
306
- rewriter.create <arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
309
+ loc,
310
+ rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
311
+ lhsReal, fmf);
312
+ Value resultImag1 = rewriter.create <arith::DivFOp>(loc, imagNumerator1,
313
+ rhsRealImagDenom, fmf);
307
314
308
315
Value rhsImagRealRatio =
309
- rewriter.create <arith::DivFOp>(loc, rhsImag, rhsReal);
316
+ rewriter.create <arith::DivFOp>(loc, rhsImag, rhsReal, fmf );
310
317
Value rhsImagRealDenom = rewriter.create <arith::AddFOp>(
311
318
loc, rhsReal,
312
- rewriter.create <arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
319
+ rewriter.create <arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
320
+ fmf);
313
321
Value realNumerator2 = rewriter.create <arith::AddFOp>(
314
322
loc, lhsReal,
315
- rewriter.create <arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
316
- Value resultReal2 =
317
- rewriter.create <arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
323
+ rewriter.create <arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
324
+ fmf);
325
+ Value resultReal2 = rewriter.create <arith::DivFOp>(loc, realNumerator2,
326
+ rhsImagRealDenom, fmf);
318
327
Value imagNumerator2 = rewriter.create <arith::SubFOp>(
319
328
loc, lhsImag,
320
- rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
321
- Value resultImag2 =
322
- rewriter.create <arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
329
+ rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
330
+ fmf);
331
+ Value resultImag2 = rewriter.create <arith::DivFOp>(loc, imagNumerator2,
332
+ rhsImagRealDenom, fmf);
323
333
324
334
// Consider corner cases.
325
335
// Case 1. Zero denominator, numerator contains at most one NaN value.
326
336
Value zero = rewriter.create <arith::ConstantOp>(
327
337
loc, elementType, rewriter.getZeroAttr (elementType));
328
- Value rhsRealAbs = rewriter.create <math::AbsFOp>(loc, rhsReal);
338
+ Value rhsRealAbs = rewriter.create <math::AbsFOp>(loc, rhsReal, fmf );
329
339
Value rhsRealIsZero = rewriter.create <arith::CmpFOp>(
330
340
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
331
- Value rhsImagAbs = rewriter.create <math::AbsFOp>(loc, rhsImag);
341
+ Value rhsImagAbs = rewriter.create <math::AbsFOp>(loc, rhsImag, fmf );
332
342
Value rhsImagIsZero = rewriter.create <arith::CmpFOp>(
333
343
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
334
344
Value lhsRealIsNotNaN = rewriter.create <arith::CmpFOp>(
@@ -347,9 +357,9 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
347
357
Value infWithSignOfRhsReal =
348
358
rewriter.create <math::CopySignOp>(loc, inf, rhsReal);
349
359
Value infinityResultReal =
350
- rewriter.create <arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
360
+ rewriter.create <arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf );
351
361
Value infinityResultImag =
352
- rewriter.create <arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
362
+ rewriter.create <arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf );
353
363
354
364
// Case 2. Infinite numerator, finite denominator.
355
365
Value rhsRealFinite = rewriter.create <arith::CmpFOp>(
@@ -358,10 +368,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
358
368
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
359
369
Value rhsFinite =
360
370
rewriter.create <arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
361
- Value lhsRealAbs = rewriter.create <math::AbsFOp>(loc, lhsReal);
371
+ Value lhsRealAbs = rewriter.create <math::AbsFOp>(loc, lhsReal, fmf );
362
372
Value lhsRealInfinite = rewriter.create <arith::CmpFOp>(
363
373
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
364
- Value lhsImagAbs = rewriter.create <math::AbsFOp>(loc, lhsImag);
374
+ Value lhsImagAbs = rewriter.create <math::AbsFOp>(loc, lhsImag, fmf );
365
375
Value lhsImagInfinite = rewriter.create <arith::CmpFOp>(
366
376
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
367
377
Value lhsInfinite =
@@ -377,21 +387,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
377
387
loc, rewriter.create <arith::SelectOp>(loc, lhsImagInfinite, one, zero),
378
388
lhsImag);
379
389
Value lhsRealIsInfWithSignTimesRhsReal =
380
- rewriter.create <arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
390
+ rewriter.create <arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf );
381
391
Value lhsImagIsInfWithSignTimesRhsImag =
382
- rewriter.create <arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
392
+ rewriter.create <arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf );
383
393
Value resultReal3 = rewriter.create <arith::MulFOp>(
384
394
loc, inf,
385
395
rewriter.create <arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
386
- lhsImagIsInfWithSignTimesRhsImag));
396
+ lhsImagIsInfWithSignTimesRhsImag, fmf),
397
+ fmf);
387
398
Value lhsRealIsInfWithSignTimesRhsImag =
388
- rewriter.create <arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
399
+ rewriter.create <arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf );
389
400
Value lhsImagIsInfWithSignTimesRhsReal =
390
- rewriter.create <arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
401
+ rewriter.create <arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf );
391
402
Value resultImag3 = rewriter.create <arith::MulFOp>(
392
403
loc, inf,
393
404
rewriter.create <arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
394
- lhsRealIsInfWithSignTimesRhsImag));
405
+ lhsRealIsInfWithSignTimesRhsImag, fmf),
406
+ fmf);
395
407
396
408
// Case 3. Finite numerator, infinite denominator.
397
409
Value lhsRealFinite = rewriter.create <arith::CmpFOp>(
@@ -415,21 +427,23 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
415
427
loc, rewriter.create <arith::SelectOp>(loc, rhsImagInfinite, one, zero),
416
428
rhsImag);
417
429
Value rhsRealIsInfWithSignTimesLhsReal =
418
- rewriter.create <arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
430
+ rewriter.create <arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf );
419
431
Value rhsImagIsInfWithSignTimesLhsImag =
420
- rewriter.create <arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
432
+ rewriter.create <arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf );
421
433
Value resultReal4 = rewriter.create <arith::MulFOp>(
422
434
loc, zero,
423
435
rewriter.create <arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
424
- rhsImagIsInfWithSignTimesLhsImag));
436
+ rhsImagIsInfWithSignTimesLhsImag, fmf),
437
+ fmf);
425
438
Value rhsRealIsInfWithSignTimesLhsImag =
426
- rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
439
+ rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf );
427
440
Value rhsImagIsInfWithSignTimesLhsReal =
428
- rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
441
+ rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf );
429
442
Value resultImag4 = rewriter.create <arith::MulFOp>(
430
443
loc, zero,
431
444
rewriter.create <arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
432
- rhsImagIsInfWithSignTimesLhsReal));
445
+ rhsImagIsInfWithSignTimesLhsReal, fmf),
446
+ fmf);
433
447
434
448
Value realAbsSmallerThanImagAbs = rewriter.create <arith::CmpFOp>(
435
449
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
0 commit comments