@@ -429,29 +429,50 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
429
429
430
430
auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
431
431
LegalizeAction NoF16Action) { // Registers Action for (Op, VT) when the op is usable with FP16, otherwise NoF16Action.
432
- setOperationAction (Op, VT, STI.allowFP16Math () ? Action : NoF16Action);
432
+ bool IsOpSupported = STI.allowFP16Math (); // Baseline requirement for every FP16 op.
433
+ switch (Op) {
434
+ // These FP16 min/max operations additionally require sm_80 or newer together with PTX 7.0 or newer.
435
+ case ISD::FMINNUM:
436
+ case ISD::FMAXNUM:
437
+ case ISD::FMAXNUM_IEEE:
438
+ case ISD::FMINNUM_IEEE:
439
+ case ISD::FMAXIMUM:
440
+ case ISD::FMINIMUM:
441
+ IsOpSupported &= STI.getSmVersion () >= 80 && STI.getPTXVersion () >= 70 ; // &= keeps the allowFP16Math() requirement and adds the version check.
442
+ break ;
443
+ }
444
+ setOperationAction (Op, VT, IsOpSupported ? Action : NoF16Action);
433
445
};
434
446
435
447
auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
436
448
LegalizeAction NoBF16Action) {
437
449
bool IsOpSupported = STI.hasBF16Math ();
438
- // Few instructions are available on sm_90 only
439
- switch (Op) {
440
- case ISD::FADD:
441
- case ISD::FMUL:
442
- case ISD::FSUB:
443
- case ISD::SELECT:
444
- case ISD::SELECT_CC:
445
- case ISD::SETCC:
446
- case ISD::FEXP2:
447
- case ISD::FCEIL:
448
- case ISD::FFLOOR:
449
- case ISD::FNEARBYINT:
450
- case ISD::FRINT:
451
- case ISD::FROUNDEVEN:
452
- case ISD::FTRUNC:
453
- IsOpSupported = STI.getSmVersion () >= 90 && STI.getPTXVersion () >= 78 ;
454
- break ;
450
+ switch (Op) {
451
+ // Several BF16 instructions are only available on sm_90 and newer (with PTX 7.8 or newer).
452
+ case ISD::FADD:
453
+ case ISD::FMUL:
454
+ case ISD::FSUB:
455
+ case ISD::SELECT:
456
+ case ISD::SELECT_CC:
457
+ case ISD::SETCC:
458
+ case ISD::FEXP2:
459
+ case ISD::FCEIL:
460
+ case ISD::FFLOOR:
461
+ case ISD::FNEARBYINT:
462
+ case ISD::FRINT:
463
+ case ISD::FROUNDEVEN:
464
+ case ISD::FTRUNC:
465
+ IsOpSupported = STI.getSmVersion () >= 90 && STI.getPTXVersion () >= 78 ;
466
+ break ;
467
+ // Several BF16 instructions are only available on sm_80 and newer (with PTX 7.0 or newer).
468
+ case ISD::FMINNUM:
469
+ case ISD::FMAXNUM:
470
+ case ISD::FMAXNUM_IEEE:
471
+ case ISD::FMINNUM_IEEE:
472
+ case ISD::FMAXIMUM:
473
+ case ISD::FMINIMUM:
474
+ IsOpSupported &= STI.getSmVersion () >= 80 && STI.getPTXVersion () >= 70 ;
475
+ break ;
455
476
}
456
477
setOperationAction (
457
478
Op, VT, IsOpSupported ? Action : NoBF16Action);
@@ -838,26 +859,23 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
838
859
AddPromotedToType (Op, MVT::bf16 , MVT::f32 );
839
860
}
840
861
841
- // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
842
- auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
843
- bool IsAtLeastSm80 = STI.getSmVersion () >= 80 && STI.getPTXVersion () >= 70 ;
844
- return IsAtLeastSm80 ? Legal : NotSm80Action;
845
- };
846
862
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
847
- setFP16OperationAction (Op, MVT::f16 , GetMinMaxAction (Promote), Promote);
848
863
setOperationAction (Op, MVT::f32 , Legal);
849
864
setOperationAction (Op, MVT::f64 , Legal);
850
- setFP16OperationAction (Op, MVT::v2f16, GetMinMaxAction (Expand), Expand);
865
+ setFP16OperationAction (Op, MVT::f16 , Legal, Promote);
866
+ setFP16OperationAction (Op, MVT::v2f16, Legal, Expand);
851
867
setBF16OperationAction (Op, MVT::v2bf16, Legal, Expand);
852
868
setBF16OperationAction (Op, MVT::bf16 , Legal, Promote);
853
869
if (getOperationAction (Op, MVT::bf16 ) == Promote)
854
870
AddPromotedToType (Op, MVT::bf16 , MVT::f32 );
855
871
}
872
+ bool SupportsF32MinMaxNaN =
873
+ STI.getSmVersion () >= 80 && STI.getPTXVersion () >= 70 ;
856
874
for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
857
- setFP16OperationAction (Op, MVT::f16 , GetMinMaxAction (Expand), Expand);
858
- setFP16OperationAction (Op, MVT::bf16 , Legal, Expand);
859
- setOperationAction (Op, MVT::f32 , GetMinMaxAction ( Expand) );
860
- setFP16OperationAction (Op, MVT::v2f16, GetMinMaxAction (Expand) , Expand);
875
+ setOperationAction (Op, MVT::f32 , SupportsF32MinMaxNaN ? Legal : Expand);
876
+ setFP16OperationAction (Op, MVT::f16 , Legal, Expand);
877
+ setFP16OperationAction (Op, MVT::v2f16, Legal, Expand);
878
+ setBF16OperationAction (Op, MVT::bf16 , Legal , Expand);
861
879
setBF16OperationAction (Op, MVT::v2bf16, Legal, Expand);
862
880
}
863
881
0 commit comments