Skip to content

Commit edac1b2

Browse files
authored
[RISCV] Promote bf16 ops to f32 with zvfbfmin (#108937)
For f16 with zvfhmin, we promote most ops and VP ops to f32. This does the same for bf16 with zvfbfmin, so the two fp types should now be in sync. There are a few places in the custom lowering where we need to check for a LMUL 8 f16/bf16 vector that can't be promoted and must be split, this extracts that out into isPromotedOpNeedingSplit. In a follow up NFC we can deduplicate the code that sets up the promotions.
1 parent e32a62c commit edac1b2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+17683
-1582
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 70 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
941941
};
942942

943943
// TODO: support more ops.
944-
static const unsigned ZvfhminPromoteOps[] = {
944+
static const unsigned ZvfhminZvfbfminPromoteOps[] = {
945945
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
946946
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
947947
ISD::FCEIL, ISD::FTRUNC, ISD::FFLOOR, ISD::FROUND,
@@ -951,30 +951,31 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
951951
ISD::STRICT_FMA};
952952

953953
// TODO: support more vp ops.
954-
static const unsigned ZvfhminPromoteVPOps[] = {ISD::VP_FADD,
955-
ISD::VP_FSUB,
956-
ISD::VP_FMUL,
957-
ISD::VP_FDIV,
958-
ISD::VP_FMA,
959-
ISD::VP_REDUCE_FADD,
960-
ISD::VP_REDUCE_SEQ_FADD,
961-
ISD::VP_REDUCE_FMIN,
962-
ISD::VP_REDUCE_FMAX,
963-
ISD::VP_SQRT,
964-
ISD::VP_FMINNUM,
965-
ISD::VP_FMAXNUM,
966-
ISD::VP_FCEIL,
967-
ISD::VP_FFLOOR,
968-
ISD::VP_FROUND,
969-
ISD::VP_FROUNDEVEN,
970-
ISD::VP_FROUNDTOZERO,
971-
ISD::VP_FRINT,
972-
ISD::VP_FNEARBYINT,
973-
ISD::VP_SETCC,
974-
ISD::VP_FMINIMUM,
975-
ISD::VP_FMAXIMUM,
976-
ISD::VP_REDUCE_FMINIMUM,
977-
ISD::VP_REDUCE_FMAXIMUM};
954+
static const unsigned ZvfhminZvfbfminPromoteVPOps[] = {
955+
ISD::VP_FADD,
956+
ISD::VP_FSUB,
957+
ISD::VP_FMUL,
958+
ISD::VP_FDIV,
959+
ISD::VP_FMA,
960+
ISD::VP_REDUCE_FADD,
961+
ISD::VP_REDUCE_SEQ_FADD,
962+
ISD::VP_REDUCE_FMIN,
963+
ISD::VP_REDUCE_FMAX,
964+
ISD::VP_SQRT,
965+
ISD::VP_FMINNUM,
966+
ISD::VP_FMAXNUM,
967+
ISD::VP_FCEIL,
968+
ISD::VP_FFLOOR,
969+
ISD::VP_FROUND,
970+
ISD::VP_FROUNDEVEN,
971+
ISD::VP_FROUNDTOZERO,
972+
ISD::VP_FRINT,
973+
ISD::VP_FNEARBYINT,
974+
ISD::VP_SETCC,
975+
ISD::VP_FMINIMUM,
976+
ISD::VP_FMAXIMUM,
977+
ISD::VP_REDUCE_FMINIMUM,
978+
ISD::VP_REDUCE_FMAXIMUM};
978979

979980
// Sets common operation actions on RVV floating-point vector types.
980981
const auto SetCommonVFPActions = [&](MVT VT) {
@@ -1097,20 +1098,20 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
10971098
setOperationAction(ISD::FABS, VT, Expand);
10981099
setOperationAction(ISD::FCOPYSIGN, VT, Expand);
10991100

1100-
// Custom split nxv32f16 since nxv32f32 if not legal.
1101+
// Custom split nxv32f16 since nxv32f32 is not legal.
11011102
if (VT == MVT::nxv32f16) {
1102-
setOperationAction(ZvfhminPromoteOps, VT, Custom);
1103-
setOperationAction(ZvfhminPromoteVPOps, VT, Custom);
1103+
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
1104+
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
11041105
continue;
11051106
}
11061107
// Add more promote ops.
11071108
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1108-
setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT);
1109-
setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT);
1109+
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
1110+
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
11101111
}
11111112
}
11121113

1113-
// TODO: Could we merge some code with zvfhmin?
1114+
// TODO: merge with zvfhmin
11141115
if (Subtarget.hasVInstructionsBF16Minimal()) {
11151116
for (MVT VT : BF16VecVTs) {
11161117
if (!isTypeLegal(VT))
@@ -1139,7 +1140,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
11391140
setOperationAction(ISD::FABS, VT, Expand);
11401141
setOperationAction(ISD::FCOPYSIGN, VT, Expand);
11411142

1142-
// TODO: Promote to fp32.
1143+
// Custom split nxv32f16 since nxv32f32 is not legal.
1144+
if (VT == MVT::nxv32bf16) {
1145+
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
1146+
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
1147+
continue;
1148+
}
1149+
// Add more promote ops.
1150+
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1151+
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
1152+
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
11431153
}
11441154
}
11451155

@@ -1375,8 +1385,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13751385
// TODO: could split the f16 vector into two vectors and do promotion.
13761386
if (!isTypeLegal(F32VecVT))
13771387
continue;
1378-
setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT);
1379-
setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT);
1388+
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
1389+
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
13801390
continue;
13811391
}
13821392

@@ -6333,6 +6343,17 @@ static bool hasMaskOp(unsigned Opcode) {
63336343
return false;
63346344
}
63356345

6346+
static bool isPromotedOpNeedingSplit(SDValue Op,
6347+
const RISCVSubtarget &Subtarget) {
6348+
if (Op.getValueType() == MVT::nxv32f16 &&
6349+
(Subtarget.hasVInstructionsF16Minimal() &&
6350+
!Subtarget.hasVInstructionsF16()))
6351+
return true;
6352+
if (Op.getValueType() == MVT::nxv32bf16)
6353+
return true;
6354+
return false;
6355+
}
6356+
63366357
static SDValue SplitVectorOp(SDValue Op, SelectionDAG &DAG) {
63376358
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op.getValueType());
63386359
SDLoc DL(Op);
@@ -6670,9 +6691,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
66706691
}
66716692
case ISD::FMAXIMUM:
66726693
case ISD::FMINIMUM:
6673-
if (Op.getValueType() == MVT::nxv32f16 &&
6674-
(Subtarget.hasVInstructionsF16Minimal() &&
6675-
!Subtarget.hasVInstructionsF16()))
6694+
if (isPromotedOpNeedingSplit(Op, Subtarget))
66766695
return SplitVectorOp(Op, DAG);
66776696
return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
66786697
case ISD::FP_EXTEND:
@@ -6688,8 +6707,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
66886707
(Subtarget.hasVInstructionsF16Minimal() &&
66896708
!Subtarget.hasVInstructionsF16())) ||
66906709
Op.getValueType().getScalarType() == MVT::bf16)) {
6691-
if (Op.getValueType() == MVT::nxv32f16 ||
6692-
Op.getValueType() == MVT::nxv32bf16)
6710+
if (isPromotedOpNeedingSplit(Op, Subtarget))
66936711
return SplitVectorOp(Op, DAG);
66946712
// int -> f32
66956713
SDLoc DL(Op);
@@ -6709,8 +6727,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
67096727
(Subtarget.hasVInstructionsF16Minimal() &&
67106728
!Subtarget.hasVInstructionsF16())) ||
67116729
Op1.getValueType().getScalarType() == MVT::bf16)) {
6712-
if (Op1.getValueType() == MVT::nxv32f16 ||
6713-
Op1.getValueType() == MVT::nxv32bf16)
6730+
if (isPromotedOpNeedingSplit(Op1, Subtarget))
67146731
return SplitVectorOp(Op, DAG);
67156732
// [b]f16 -> f32
67166733
SDLoc DL(Op);
@@ -6942,9 +6959,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
69426959
case ISD::FRINT:
69436960
case ISD::FROUND:
69446961
case ISD::FROUNDEVEN:
6945-
if (Op.getValueType() == MVT::nxv32f16 &&
6946-
(Subtarget.hasVInstructionsF16Minimal() &&
6947-
!Subtarget.hasVInstructionsF16()))
6962+
if (isPromotedOpNeedingSplit(Op, Subtarget))
69486963
return SplitVectorOp(Op, DAG);
69496964
return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
69506965
case ISD::LRINT:
@@ -7002,9 +7017,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
70027017
case ISD::VP_REDUCE_FMAX:
70037018
case ISD::VP_REDUCE_FMINIMUM:
70047019
case ISD::VP_REDUCE_FMAXIMUM:
7005-
if (Op.getOperand(1).getValueType() == MVT::nxv32f16 &&
7006-
(Subtarget.hasVInstructionsF16Minimal() &&
7007-
!Subtarget.hasVInstructionsF16()))
7020+
if (isPromotedOpNeedingSplit(Op.getOperand(1), Subtarget))
70087021
return SplitVectorReductionOp(Op, DAG);
70097022
return lowerVPREDUCE(Op, DAG);
70107023
case ISD::VP_REDUCE_AND:
@@ -7251,9 +7264,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
72517264
return DAG.getSetCC(DL, VT, RHS, LHS, CCVal);
72527265
}
72537266

7254-
if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 &&
7255-
(Subtarget.hasVInstructionsF16Minimal() &&
7256-
!Subtarget.hasVInstructionsF16()))
7267+
if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
72577268
return SplitVectorOp(Op, DAG);
72587269

72597270
return lowerFixedLengthVectorSetccToRVV(Op, DAG);
@@ -7295,9 +7306,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
72957306
case ISD::FMA:
72967307
case ISD::FMINNUM:
72977308
case ISD::FMAXNUM:
7298-
if (Op.getValueType() == MVT::nxv32f16 &&
7299-
(Subtarget.hasVInstructionsF16Minimal() &&
7300-
!Subtarget.hasVInstructionsF16()))
7309+
if (isPromotedOpNeedingSplit(Op, Subtarget))
73017310
return SplitVectorOp(Op, DAG);
73027311
[[fallthrough]];
73037312
case ISD::AVGFLOORS:
@@ -7345,9 +7354,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
73457354
case ISD::FCOPYSIGN:
73467355
if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
73477356
return lowerFCOPYSIGN(Op, DAG, Subtarget);
7348-
if (Op.getValueType() == MVT::nxv32f16 &&
7349-
(Subtarget.hasVInstructionsF16Minimal() &&
7350-
!Subtarget.hasVInstructionsF16()))
7357+
if (isPromotedOpNeedingSplit(Op, Subtarget))
73517358
return SplitVectorOp(Op, DAG);
73527359
return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
73537360
case ISD::STRICT_FADD:
@@ -7356,9 +7363,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
73567363
case ISD::STRICT_FDIV:
73577364
case ISD::STRICT_FSQRT:
73587365
case ISD::STRICT_FMA:
7359-
if (Op.getValueType() == MVT::nxv32f16 &&
7360-
(Subtarget.hasVInstructionsF16Minimal() &&
7361-
!Subtarget.hasVInstructionsF16()))
7366+
if (isPromotedOpNeedingSplit(Op, Subtarget))
73627367
return SplitStrictFPVectorOp(Op, DAG);
73637368
return lowerToScalableOp(Op, DAG);
73647369
case ISD::STRICT_FSETCC:
@@ -7415,9 +7420,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
74157420
case ISD::VP_FMINNUM:
74167421
case ISD::VP_FMAXNUM:
74177422
case ISD::VP_FCOPYSIGN:
7418-
if (Op.getValueType() == MVT::nxv32f16 &&
7419-
(Subtarget.hasVInstructionsF16Minimal() &&
7420-
!Subtarget.hasVInstructionsF16()))
7423+
if (isPromotedOpNeedingSplit(Op, Subtarget))
74217424
return SplitVPOp(Op, DAG);
74227425
[[fallthrough]];
74237426
case ISD::VP_SRA:
@@ -7443,8 +7446,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
74437446
(Subtarget.hasVInstructionsF16Minimal() &&
74447447
!Subtarget.hasVInstructionsF16())) ||
74457448
Op.getValueType().getScalarType() == MVT::bf16)) {
7446-
if (Op.getValueType() == MVT::nxv32f16 ||
7447-
Op.getValueType() == MVT::nxv32bf16)
7449+
if (isPromotedOpNeedingSplit(Op, Subtarget))
74487450
return SplitVectorOp(Op, DAG);
74497451
// int -> f32
74507452
SDLoc DL(Op);
@@ -7464,8 +7466,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
74647466
(Subtarget.hasVInstructionsF16Minimal() &&
74657467
!Subtarget.hasVInstructionsF16())) ||
74667468
Op1.getValueType().getScalarType() == MVT::bf16)) {
7467-
if (Op1.getValueType() == MVT::nxv32f16 ||
7468-
Op1.getValueType() == MVT::nxv32bf16)
7469+
if (isPromotedOpNeedingSplit(Op1, Subtarget))
74697470
return SplitVectorOp(Op, DAG);
74707471
// [b]f16 -> f32
74717472
SDLoc DL(Op);
@@ -7478,9 +7479,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
74787479
}
74797480
return lowerVPFPIntConvOp(Op, DAG);
74807481
case ISD::VP_SETCC:
7481-
if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 &&
7482-
(Subtarget.hasVInstructionsF16Minimal() &&
7483-
!Subtarget.hasVInstructionsF16()))
7482+
if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
74847483
return SplitVPOp(Op, DAG);
74857484
if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
74867485
return lowerVPSetCCMaskOp(Op, DAG);
@@ -7515,16 +7514,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
75157514
case ISD::VP_FROUND:
75167515
case ISD::VP_FROUNDEVEN:
75177516
case ISD::VP_FROUNDTOZERO:
7518-
if (Op.getValueType() == MVT::nxv32f16 &&
7519-
(Subtarget.hasVInstructionsF16Minimal() &&
7520-
!Subtarget.hasVInstructionsF16()))
7517+
if (isPromotedOpNeedingSplit(Op, Subtarget))
75217518
return SplitVPOp(Op, DAG);
75227519
return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
75237520
case ISD::VP_FMAXIMUM:
75247521
case ISD::VP_FMINIMUM:
7525-
if (Op.getValueType() == MVT::nxv32f16 &&
7526-
(Subtarget.hasVInstructionsF16Minimal() &&
7527-
!Subtarget.hasVInstructionsF16()))
7522+
if (isPromotedOpNeedingSplit(Op, Subtarget))
75287523
return SplitVPOp(Op, DAG);
75297524
return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
75307525
case ISD::EXPERIMENTAL_VP_SPLICE:

0 commit comments

Comments
 (0)