@@ -150,8 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
150
150
151
151
def doMulWide : Predicate<"doMulWide">; // Condition string is the bare name doMulWide (no call parentheses, unlike the neighbouring predicates).
152
152
153
- def allowFMA : Predicate<"allowFMA()">;
154
- def noFMA : Predicate<"!allowFMA()">;
155
153
def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">; // Holds when allowUnsafeFPMath() evaluates to true.
156
154
def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">; // Logical negation of the allowUnsafeFPMath condition above.
157
155
@@ -367,167 +365,89 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
367
365
// This multiclass should be used for nodes that can be folded to make fma ops.
368
366
// In this case, we use the ".rn" variant when FMA is disabled, as this behaves
369
367
// just like the non ".rn" op, but prevents ptxas from creating FMAs.
370
- multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
371
- def f64rr :
372
- NVPTXInst<(outs Float64Regs:$dst),
373
- (ins Float64Regs:$a, Float64Regs:$b),
374
- !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
375
- [(set f64:$dst, (OpNode f64:$a, f64:$b))]>,
376
- Requires<[allowFMA]>;
377
- def f64ri :
378
- NVPTXInst<(outs Float64Regs:$dst),
379
- (ins Float64Regs:$a, f64imm:$b),
380
- !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
381
- [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>,
382
- Requires<[allowFMA]>;
383
- def f32rr_ftz :
384
- NVPTXInst<(outs Float32Regs:$dst),
385
- (ins Float32Regs:$a, Float32Regs:$b),
386
- !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
387
- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
388
- Requires<[allowFMA, doF32FTZ]>;
389
- def f32ri_ftz :
390
- NVPTXInst<(outs Float32Regs:$dst),
391
- (ins Float32Regs:$a, f32imm:$b),
392
- !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
393
- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
394
- Requires<[allowFMA, doF32FTZ]>;
395
- def f32rr :
396
- NVPTXInst<(outs Float32Regs:$dst),
397
- (ins Float32Regs:$a, Float32Regs:$b),
398
- !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
399
- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
400
- Requires<[allowFMA]>;
401
- def f32ri :
402
- NVPTXInst<(outs Float32Regs:$dst),
403
- (ins Float32Regs:$a, f32imm:$b),
404
- !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
405
- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
406
- Requires<[allowFMA]>;
407
-
408
- def f16rr_ftz :
409
- NVPTXInst<(outs Int16Regs:$dst),
410
- (ins Int16Regs:$a, Int16Regs:$b),
411
- !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
412
- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
413
- Requires<[useFP16Math, allowFMA, doF32FTZ]>;
414
- def f16rr :
415
- NVPTXInst<(outs Int16Regs:$dst),
416
- (ins Int16Regs:$a, Int16Regs:$b),
417
- !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
418
- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
419
- Requires<[useFP16Math, allowFMA]>;
368
+ multiclass F3<string op_str, SDPatternOperator op_pat> { // Expands one mnemonic (op_str) and one pattern operator (op_pat) into two-operand f64, f32, f16, f16x2, bf16 and bf16x2 instructions; no FMA-related predicate is applied here.
369
+ def f64rr : // f64, register/register form.
370
+ NVPTXInst<(outs Float64Regs:$dst),
371
+ (ins Float64Regs:$a, Float64Regs:$b),
372
+ op_str # ".f64 \t$dst, $a, $b;",
373
+ [(set f64:$dst, (op_pat f64:$a, f64:$b))]>;
374
+ def f64ri : // f64, register/immediate form ($b is matched as fpimm).
375
+ NVPTXInst<(outs Float64Regs:$dst),
376
+ (ins Float64Regs:$a, f64imm:$b),
377
+ op_str # ".f64 \t$dst, $a, $b;",
378
+ [(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
379
+ def f32rr_ftz : // f32 register/register with the ".ftz" modifier; gated on doF32FTZ.
380
+ NVPTXInst<(outs Float32Regs:$dst),
381
+ (ins Float32Regs:$a, Float32Regs:$b),
382
+ op_str # ".ftz.f32 \t$dst, $a, $b;",
383
+ [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
384
+ Requires<[doF32FTZ]>;
385
+ def f32ri_ftz : // f32 register/immediate with the ".ftz" modifier; gated on doF32FTZ.
386
+ NVPTXInst<(outs Float32Regs:$dst),
387
+ (ins Float32Regs:$a, f32imm:$b),
388
+ op_str # ".ftz.f32 \t$dst, $a, $b;",
389
+ [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
390
+ Requires<[doF32FTZ]>;
391
+ def f32rr : // f32 register/register without ".ftz"; carries no Requires list.
392
+ NVPTXInst<(outs Float32Regs:$dst),
393
+ (ins Float32Regs:$a, Float32Regs:$b),
394
+ op_str # ".f32 \t$dst, $a, $b;",
395
+ [(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
396
+ def f32ri : // f32 register/immediate without ".ftz"; carries no Requires list.
397
+ NVPTXInst<(outs Float32Regs:$dst),
398
+ (ins Float32Regs:$a, f32imm:$b),
399
+ op_str # ".f32 \t$dst, $a, $b;",
400
+ [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
401
+
402
+ def f16rr_ftz : // f16 values live in Int16Regs; ".ftz" form needs both useFP16Math and doF32FTZ.
403
+ NVPTXInst<(outs Int16Regs:$dst),
404
+ (ins Int16Regs:$a, Int16Regs:$b),
405
+ op_str # ".ftz.f16 \t$dst, $a, $b;",
406
+ [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
407
+ Requires<[useFP16Math, doF32FTZ]>;
408
+ def f16rr : // f16 form without ".ftz"; needs useFP16Math only.
409
+ NVPTXInst<(outs Int16Regs:$dst),
410
+ (ins Int16Regs:$a, Int16Regs:$b),
411
+ op_str # ".f16 \t$dst, $a, $b;",
412
+ [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
413
+ Requires<[useFP16Math]>;
414
+
415
+ def f16x2rr_ftz : // v2f16 values live in Int32Regs; ".ftz" form needs both useFP16Math and doF32FTZ.
416
+ NVPTXInst<(outs Int32Regs:$dst),
417
+ (ins Int32Regs:$a, Int32Regs:$b),
418
+ op_str # ".ftz.f16x2 \t$dst, $a, $b;",
419
+ [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
420
+ Requires<[useFP16Math, doF32FTZ]>;
421
+ def f16x2rr : // v2f16 form without ".ftz"; needs useFP16Math only.
422
+ NVPTXInst<(outs Int32Regs:$dst),
423
+ (ins Int32Regs:$a, Int32Regs:$b),
424
+ op_str # ".f16x2 \t$dst, $a, $b;",
425
+ [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
426
+ Requires<[useFP16Math]>;
427
+ def bf16rr : // bf16 form in Int16Regs; gated on hasBF16Math. This multiclass defines no ".ftz" bf16 variant.
428
+ NVPTXInst<(outs Int16Regs:$dst),
429
+ (ins Int16Regs:$a, Int16Regs:$b),
430
+ op_str # ".bf16 \t$dst, $a, $b;",
431
+ [(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
432
+ Requires<[hasBF16Math]>;
433
+
434
+ def bf16x2rr : // v2bf16 form in Int32Regs; gated on hasBF16Math.
435
+ NVPTXInst<(outs Int32Regs:$dst),
436
+ (ins Int32Regs:$a, Int32Regs:$b),
437
+ op_str # ".bf16x2 \t$dst, $a, $b;",
438
+ [(set v2bf16:$dst, (op_pat v2bf16:$a, v2bf16:$b))]>,
439
+ Requires<[hasBF16Math]>;
440
+ }
420
441
421
- def f16x2rr_ftz :
422
- NVPTXInst<(outs Int32Regs:$dst),
423
- (ins Int32Regs:$a, Int32Regs:$b),
424
- !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
425
- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
426
- Requires<[useFP16Math, allowFMA, doF32FTZ]>;
427
- def f16x2rr :
428
- NVPTXInst<(outs Int32Regs:$dst),
429
- (ins Int32Regs:$a, Int32Regs:$b),
430
- !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
431
- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
432
- Requires<[useFP16Math, allowFMA]>;
433
- def bf16rr :
434
- NVPTXInst<(outs Int16Regs:$dst),
435
- (ins Int16Regs:$a, Int16Regs:$b),
436
- !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
437
- [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
438
- Requires<[hasBF16Math, allowFMA]>;
442
+ class BinOpAllowsFMA<SDPatternOperator operator> // Pattern fragment for a two-operand `operator` that matches only when the C++ predicate below returns true.
443
+ : PatFrag<(ops node:$A, node:$B),
444
+ (operator node:$A, node:$B), [{
445
+ return allowFMA() || N->getFlags().hasAllowContract(); // Accept if allowFMA() holds or this node carries the allow-contract flag.
446
+ }]>;
439
447
440
- def bf16x2rr :
441
- NVPTXInst<(outs Int32Regs:$dst),
442
- (ins Int32Regs:$a, Int32Regs:$b),
443
- !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
444
- [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
445
- Requires<[hasBF16Math, allowFMA]>;
446
- // These have strange names so we don't perturb existing mir tests.
447
- def _rnf64rr :
448
- NVPTXInst<(outs Float64Regs:$dst),
449
- (ins Float64Regs:$a, Float64Regs:$b),
450
- !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
451
- [(set f64:$dst, (OpNode f64:$a, f64:$b))]>,
452
- Requires<[noFMA]>;
453
- def _rnf64ri :
454
- NVPTXInst<(outs Float64Regs:$dst),
455
- (ins Float64Regs:$a, f64imm:$b),
456
- !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
457
- [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>,
458
- Requires<[noFMA]>;
459
- def _rnf32rr_ftz :
460
- NVPTXInst<(outs Float32Regs:$dst),
461
- (ins Float32Regs:$a, Float32Regs:$b),
462
- !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
463
- [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b))]>,
464
- Requires<[noFMA, doF32FTZ]>;
465
- def _rnf32ri_ftz :
466
- NVPTXInst<(outs Float32Regs:$dst),
467
- (ins Float32Regs:$a, f32imm:$b),
468
- !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
469
- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
470
- Requires<[noFMA, doF32FTZ]>;
471
- def _rnf32rr :
472
- NVPTXInst<(outs Float32Regs:$dst),
473
- (ins Float32Regs:$a, Float32Regs:$b),
474
- !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
475
- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
476
- Requires<[noFMA]>;
477
- def _rnf32ri :
478
- NVPTXInst<(outs Float32Regs:$dst),
479
- (ins Float32Regs:$a, f32imm:$b),
480
- !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
481
- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
482
- Requires<[noFMA]>;
483
- def _rnf16rr_ftz :
484
- NVPTXInst<(outs Int16Regs:$dst),
485
- (ins Int16Regs:$a, Int16Regs:$b),
486
- !strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"),
487
- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
488
- Requires<[useFP16Math, noFMA, doF32FTZ]>;
489
- def _rnf16rr :
490
- NVPTXInst<(outs Int16Regs:$dst),
491
- (ins Int16Regs:$a, Int16Regs:$b),
492
- !strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"),
493
- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
494
- Requires<[useFP16Math, noFMA]>;
495
- def _rnf16x2rr_ftz :
496
- NVPTXInst<(outs Int32Regs:$dst),
497
- (ins Int32Regs:$a, Int32Regs:$b),
498
- !strconcat(OpcStr, ".rn.ftz.f16x2 \t$dst, $a, $b;"),
499
- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
500
- Requires<[useFP16Math, noFMA, doF32FTZ]>;
501
- def _rnf16x2rr :
502
- NVPTXInst<(outs Int32Regs:$dst),
503
- (ins Int32Regs:$a, Int32Regs:$b),
504
- !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
505
- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
506
- Requires<[useFP16Math, noFMA]>;
507
- def _rnbf16rr_ftz :
508
- NVPTXInst<(outs Int16Regs:$dst),
509
- (ins Int16Regs:$a, Int16Regs:$b),
510
- !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
511
- [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
512
- Requires<[hasBF16Math, noFMA, doF32FTZ]>;
513
- def _rnbf16rr :
514
- NVPTXInst<(outs Int16Regs:$dst),
515
- (ins Int16Regs:$a, Int16Regs:$b),
516
- !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
517
- [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
518
- Requires<[hasBF16Math, noFMA]>;
519
- def _rnbf16x2rr_ftz :
520
- NVPTXInst<(outs Int32Regs:$dst),
521
- (ins Int32Regs:$a, Int32Regs:$b),
522
- !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
523
- [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
524
- Requires<[hasBF16Math, noFMA, doF32FTZ]>;
525
- def _rnbf16x2rr :
526
- NVPTXInst<(outs Int32Regs:$dst),
527
- (ins Int32Regs:$a, Int32Regs:$b),
528
- !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
529
- [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
530
- Requires<[hasBF16Math, noFMA]>;
448
+ multiclass F3_fma_component<string op_str, SDNode op_node> { // Instantiates F3 twice for one node: an unsuffixed set and a ".rn" set.
449
+ defm "" : F3<op_str, BinOpAllowsFMA<op_node>>; // Unprefixed records using plain op_str; their patterns go through BinOpAllowsFMA, so they match only when its predicate holds.
450
+ defm _rn : F3<op_str # ".rn", op_node>; // Records prefixed "_rn" whose mnemonic has ".rn" appended; matched on op_node directly, with no FMA predicate.
531
451
}
532
452
533
453
// Template for operations which take two f32 or f64 operands. Provides three
0 commit comments