@@ -195,6 +195,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
195
195
auto loc = op.getLoc ();
196
196
auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
197
197
auto elementType = cast<FloatType>(type.getElementType ());
198
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
198
199
199
200
Value real =
200
201
rewriter.create <complex::ReOp>(loc, elementType, adaptor.getComplex ());
@@ -206,11 +207,13 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
206
207
// implementation in the subclass to combine them.
207
208
Value half = rewriter.create <arith::ConstantOp>(
208
209
loc, elementType, rewriter.getFloatAttr (elementType, 0.5 ));
209
- Value exp = rewriter.create <math::ExpOp>(loc, imag);
210
- Value scaledExp = rewriter.create <arith::MulFOp>(loc, half, exp);
211
- Value reciprocalExp = rewriter.create <arith::DivFOp>(loc, half, exp);
212
- Value sin = rewriter.create <math::SinOp>(loc, real);
213
- Value cos = rewriter.create <math::CosOp>(loc, real);
210
+ Value exp = rewriter.create <math::ExpOp>(loc, imag, fmf.getValue ());
211
+ Value scaledExp =
212
+ rewriter.create <arith::MulFOp>(loc, half, exp, fmf.getValue ());
213
+ Value reciprocalExp =
214
+ rewriter.create <arith::DivFOp>(loc, half, exp, fmf.getValue ());
215
+ Value sin = rewriter.create <math::SinOp>(loc, real, fmf.getValue ());
216
+ Value cos = rewriter.create <math::CosOp>(loc, real, fmf.getValue ());
214
217
215
218
auto resultPair =
216
219
combine (loc, scaledExp, reciprocalExp, sin, cos, rewriter);
@@ -257,6 +260,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
257
260
auto loc = op.getLoc ();
258
261
auto type = cast<ComplexType>(adaptor.getLhs ().getType ());
259
262
auto elementType = cast<FloatType>(type.getElementType ());
263
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
260
264
261
265
Value lhsReal =
262
266
rewriter.create <complex::ReOp>(loc, elementType, adaptor.getLhs ());
@@ -290,45 +294,59 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
290
294
//
291
295
// See https://dl.acm.org/citation.cfm?id=368661 for more details.
292
296
Value rhsRealImagRatio =
293
- rewriter.create <arith::DivFOp>(loc, rhsReal, rhsImag);
297
+ rewriter.create <arith::DivFOp>(loc, rhsReal, rhsImag, fmf. getValue () );
294
298
Value rhsRealImagDenom = rewriter.create <arith::AddFOp>(
295
299
loc, rhsImag,
296
- rewriter.create <arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
300
+ rewriter.create <arith::MulFOp>(loc, rhsRealImagRatio, rhsReal,
301
+ fmf.getValue ()),
302
+ fmf.getValue ());
297
303
Value realNumerator1 = rewriter.create <arith::AddFOp>(
298
- loc, rewriter.create <arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
299
- lhsImag);
300
- Value resultReal1 =
301
- rewriter.create <arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
304
+ loc,
305
+ rewriter.create <arith::MulFOp>(loc, lhsReal, rhsRealImagRatio,
306
+ fmf.getValue ()),
307
+ lhsImag, fmf.getValue ());
308
+ Value resultReal1 = rewriter.create <arith::DivFOp>(
309
+ loc, realNumerator1, rhsRealImagDenom, fmf.getValue ());
302
310
Value imagNumerator1 = rewriter.create <arith::SubFOp>(
303
- loc, rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
304
- lhsReal);
305
- Value resultImag1 =
306
- rewriter.create <arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
311
+ loc,
312
+ rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealImagRatio,
313
+ fmf.getValue ()),
314
+ lhsReal, fmf.getValue ());
315
+ Value resultImag1 = rewriter.create <arith::DivFOp>(
316
+ loc, imagNumerator1, rhsRealImagDenom, fmf.getValue ());
307
317
308
318
Value rhsImagRealRatio =
309
- rewriter.create <arith::DivFOp>(loc, rhsImag, rhsReal);
319
+ rewriter.create <arith::DivFOp>(loc, rhsImag, rhsReal, fmf. getValue () );
310
320
Value rhsImagRealDenom = rewriter.create <arith::AddFOp>(
311
321
loc, rhsReal,
312
- rewriter.create <arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
322
+ rewriter.create <arith::MulFOp>(loc, rhsImagRealRatio, rhsImag,
323
+ fmf.getValue ()),
324
+ fmf.getValue ());
313
325
Value realNumerator2 = rewriter.create <arith::AddFOp>(
314
326
loc, lhsReal,
315
- rewriter.create <arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
316
- Value resultReal2 =
317
- rewriter.create <arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
327
+ rewriter.create <arith::MulFOp>(loc, lhsImag, rhsImagRealRatio,
328
+ fmf.getValue ()),
329
+ fmf.getValue ());
330
+ Value resultReal2 = rewriter.create <arith::DivFOp>(
331
+ loc, realNumerator2, rhsImagRealDenom, fmf.getValue ());
318
332
Value imagNumerator2 = rewriter.create <arith::SubFOp>(
319
333
loc, lhsImag,
320
- rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
321
- Value resultImag2 =
322
- rewriter.create <arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
334
+ rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagRealRatio,
335
+ fmf.getValue ()),
336
+ fmf.getValue ());
337
+ Value resultImag2 = rewriter.create <arith::DivFOp>(
338
+ loc, imagNumerator2, rhsImagRealDenom, fmf.getValue ());
323
339
324
340
// Consider corner cases.
325
341
// Case 1. Zero denominator, numerator contains at most one NaN value.
326
342
Value zero = rewriter.create <arith::ConstantOp>(
327
343
loc, elementType, rewriter.getZeroAttr (elementType));
328
- Value rhsRealAbs = rewriter.create <math::AbsFOp>(loc, rhsReal);
344
+ Value rhsRealAbs =
345
+ rewriter.create <math::AbsFOp>(loc, rhsReal, fmf.getValue ());
329
346
Value rhsRealIsZero = rewriter.create <arith::CmpFOp>(
330
347
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
331
- Value rhsImagAbs = rewriter.create <math::AbsFOp>(loc, rhsImag);
348
+ Value rhsImagAbs =
349
+ rewriter.create <math::AbsFOp>(loc, rhsImag, fmf.getValue ());
332
350
Value rhsImagIsZero = rewriter.create <arith::CmpFOp>(
333
351
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
334
352
Value lhsRealIsNotNaN = rewriter.create <arith::CmpFOp>(
@@ -346,10 +364,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
346
364
elementType, APFloat::getInf (elementType.getFloatSemantics ())));
347
365
Value infWithSignOfRhsReal =
348
366
rewriter.create <math::CopySignOp>(loc, inf, rhsReal);
349
- Value infinityResultReal =
350
- rewriter. create <arith::MulFOp>( loc, infWithSignOfRhsReal, lhsReal);
351
- Value infinityResultImag =
352
- rewriter. create <arith::MulFOp>( loc, infWithSignOfRhsReal, lhsImag);
367
+ Value infinityResultReal = rewriter. create <arith::MulFOp>(
368
+ loc, infWithSignOfRhsReal, lhsReal, fmf. getValue () );
369
+ Value infinityResultImag = rewriter. create <arith::MulFOp>(
370
+ loc, infWithSignOfRhsReal, lhsImag, fmf. getValue () );
353
371
354
372
// Case 2. Infinite numerator, finite denominator.
355
373
Value rhsRealFinite = rewriter.create <arith::CmpFOp>(
@@ -358,10 +376,12 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
358
376
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
359
377
Value rhsFinite =
360
378
rewriter.create <arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
361
- Value lhsRealAbs = rewriter.create <math::AbsFOp>(loc, lhsReal);
379
+ Value lhsRealAbs =
380
+ rewriter.create <math::AbsFOp>(loc, lhsReal, fmf.getValue ());
362
381
Value lhsRealInfinite = rewriter.create <arith::CmpFOp>(
363
382
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
364
- Value lhsImagAbs = rewriter.create <math::AbsFOp>(loc, lhsImag);
383
+ Value lhsImagAbs =
384
+ rewriter.create <math::AbsFOp>(loc, lhsImag, fmf.getValue ());
365
385
Value lhsImagInfinite = rewriter.create <arith::CmpFOp>(
366
386
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
367
387
Value lhsInfinite =
@@ -376,22 +396,26 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
376
396
Value lhsImagIsInfWithSign = rewriter.create <math::CopySignOp>(
377
397
loc, rewriter.create <arith::SelectOp>(loc, lhsImagInfinite, one, zero),
378
398
lhsImag);
379
- Value lhsRealIsInfWithSignTimesRhsReal =
380
- rewriter. create <arith::MulFOp>( loc, lhsRealIsInfWithSign, rhsReal);
381
- Value lhsImagIsInfWithSignTimesRhsImag =
382
- rewriter. create <arith::MulFOp>( loc, lhsImagIsInfWithSign, rhsImag);
399
+ Value lhsRealIsInfWithSignTimesRhsReal = rewriter. create <arith::MulFOp>(
400
+ loc, lhsRealIsInfWithSign, rhsReal, fmf. getValue () );
401
+ Value lhsImagIsInfWithSignTimesRhsImag = rewriter. create <arith::MulFOp>(
402
+ loc, lhsImagIsInfWithSign, rhsImag, fmf. getValue () );
383
403
Value resultReal3 = rewriter.create <arith::MulFOp>(
384
404
loc, inf,
385
405
rewriter.create <arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
386
- lhsImagIsInfWithSignTimesRhsImag));
387
- Value lhsRealIsInfWithSignTimesRhsImag =
388
- rewriter.create <arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
389
- Value lhsImagIsInfWithSignTimesRhsReal =
390
- rewriter.create <arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
406
+ lhsImagIsInfWithSignTimesRhsImag,
407
+ fmf.getValue ()),
408
+ fmf.getValue ());
409
+ Value lhsRealIsInfWithSignTimesRhsImag = rewriter.create <arith::MulFOp>(
410
+ loc, lhsRealIsInfWithSign, rhsImag, fmf.getValue ());
411
+ Value lhsImagIsInfWithSignTimesRhsReal = rewriter.create <arith::MulFOp>(
412
+ loc, lhsImagIsInfWithSign, rhsReal, fmf.getValue ());
391
413
Value resultImag3 = rewriter.create <arith::MulFOp>(
392
414
loc, inf,
393
415
rewriter.create <arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
394
- lhsRealIsInfWithSignTimesRhsImag));
416
+ lhsRealIsInfWithSignTimesRhsImag,
417
+ fmf.getValue ()),
418
+ fmf.getValue ());
395
419
396
420
// Case 3: Finite numerator, infinite denominator.
397
421
Value lhsRealFinite = rewriter.create <arith::CmpFOp>(
@@ -414,22 +438,26 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
414
438
Value rhsImagIsInfWithSign = rewriter.create <math::CopySignOp>(
415
439
loc, rewriter.create <arith::SelectOp>(loc, rhsImagInfinite, one, zero),
416
440
rhsImag);
417
- Value rhsRealIsInfWithSignTimesLhsReal =
418
- rewriter. create <arith::MulFOp>( loc, lhsReal, rhsRealIsInfWithSign);
419
- Value rhsImagIsInfWithSignTimesLhsImag =
420
- rewriter. create <arith::MulFOp>( loc, lhsImag, rhsImagIsInfWithSign);
441
+ Value rhsRealIsInfWithSignTimesLhsReal = rewriter. create <arith::MulFOp>(
442
+ loc, lhsReal, rhsRealIsInfWithSign, fmf. getValue () );
443
+ Value rhsImagIsInfWithSignTimesLhsImag = rewriter. create <arith::MulFOp>(
444
+ loc, lhsImag, rhsImagIsInfWithSign, fmf. getValue () );
421
445
Value resultReal4 = rewriter.create <arith::MulFOp>(
422
446
loc, zero,
423
447
rewriter.create <arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
424
- rhsImagIsInfWithSignTimesLhsImag));
425
- Value rhsRealIsInfWithSignTimesLhsImag =
426
- rewriter.create <arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
427
- Value rhsImagIsInfWithSignTimesLhsReal =
428
- rewriter.create <arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
448
+ rhsImagIsInfWithSignTimesLhsImag,
449
+ fmf.getValue ()),
450
+ fmf.getValue ());
451
+ Value rhsRealIsInfWithSignTimesLhsImag = rewriter.create <arith::MulFOp>(
452
+ loc, lhsImag, rhsRealIsInfWithSign, fmf.getValue ());
453
+ Value rhsImagIsInfWithSignTimesLhsReal = rewriter.create <arith::MulFOp>(
454
+ loc, lhsReal, rhsImagIsInfWithSign, fmf.getValue ());
429
455
Value resultImag4 = rewriter.create <arith::MulFOp>(
430
456
loc, zero,
431
457
rewriter.create <arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
432
- rhsImagIsInfWithSignTimesLhsReal));
458
+ rhsImagIsInfWithSignTimesLhsReal,
459
+ fmf.getValue ()),
460
+ fmf.getValue ());
433
461
434
462
Value realAbsSmallerThanImagAbs = rewriter.create <arith::CmpFOp>(
435
463
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
0 commit comments