Skip to content

Commit a9ce181

Browse files
committed
[NVPTX] Fix bugs involving maximum/minimum and bf16
The compiler would crash when targeting sufficiently old NVIDIA hardware (Volta or so) because certain operations were incorrectly marked as legal.
1 parent ea1f05e commit a9ce181

File tree

2 files changed

+1382
-253
lines changed

2 files changed

+1382
-253
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -429,29 +429,50 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
429429

430430
auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
431431
LegalizeAction NoF16Action) {
432-
setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
432+
bool IsOpSupported = STI.allowFP16Math();
433+
switch (Op) {
434+
// Several FP16 instructions are only available on sm_80 and newer (and
// require PTX version >= 70).
435+
case ISD::FMINNUM:
436+
case ISD::FMAXNUM:
437+
case ISD::FMAXNUM_IEEE:
438+
case ISD::FMINNUM_IEEE:
439+
case ISD::FMAXIMUM:
440+
case ISD::FMINIMUM:
441+
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
442+
break;
443+
}
444+
setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
433445
};
434446

435447
auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
436448
LegalizeAction NoBF16Action) {
437449
bool IsOpSupported = STI.hasBF16Math();
438-
// Few instructions are available on sm_90 only
439-
switch(Op) {
440-
case ISD::FADD:
441-
case ISD::FMUL:
442-
case ISD::FSUB:
443-
case ISD::SELECT:
444-
case ISD::SELECT_CC:
445-
case ISD::SETCC:
446-
case ISD::FEXP2:
447-
case ISD::FCEIL:
448-
case ISD::FFLOOR:
449-
case ISD::FNEARBYINT:
450-
case ISD::FRINT:
451-
case ISD::FROUNDEVEN:
452-
case ISD::FTRUNC:
453-
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
454-
break;
450+
switch (Op) {
451+
// Several BF16 instructions are only available on sm_90 and newer (and
// require PTX version >= 78).
452+
case ISD::FADD:
453+
case ISD::FMUL:
454+
case ISD::FSUB:
455+
case ISD::SELECT:
456+
case ISD::SELECT_CC:
457+
case ISD::SETCC:
458+
case ISD::FEXP2:
459+
case ISD::FCEIL:
460+
case ISD::FFLOOR:
461+
case ISD::FNEARBYINT:
462+
case ISD::FRINT:
463+
case ISD::FROUNDEVEN:
464+
case ISD::FTRUNC:
465+
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
466+
break;
467+
// Several BF16 instructions are only available on sm_80 and newer (and
// require PTX version >= 70).
468+
case ISD::FMINNUM:
469+
case ISD::FMAXNUM:
470+
case ISD::FMAXNUM_IEEE:
471+
case ISD::FMINNUM_IEEE:
472+
case ISD::FMAXIMUM:
473+
case ISD::FMINIMUM:
474+
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
475+
break;
455476
}
456477
setOperationAction(
457478
Op, VT, IsOpSupported ? Action : NoBF16Action);
@@ -838,26 +859,23 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
838859
AddPromotedToType(Op, MVT::bf16, MVT::f32);
839860
}
840861

841-
// max.f16, max.f16x2 and max.NaN are supported on sm_80+.
842-
auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
843-
bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
844-
return IsAtLeastSm80 ? Legal : NotSm80Action;
845-
};
846862
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
847-
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
848863
setOperationAction(Op, MVT::f32, Legal);
849864
setOperationAction(Op, MVT::f64, Legal);
850-
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
865+
setFP16OperationAction(Op, MVT::f16, Legal, Promote);
866+
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
851867
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
852868
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
853869
if (getOperationAction(Op, MVT::bf16) == Promote)
854870
AddPromotedToType(Op, MVT::bf16, MVT::f32);
855871
}
872+
bool SupportsF32MinMaxNaN =
873+
STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
856874
for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
857-
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
858-
setFP16OperationAction(Op, MVT::bf16, Legal, Expand);
859-
setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
860-
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
875+
setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
876+
setFP16OperationAction(Op, MVT::f16, Legal, Expand);
877+
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
878+
setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
861879
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
862880
}
863881

0 commit comments

Comments
 (0)