Skip to content

Commit c638ddd

Browse files
[LLVM][CodeGen] Add lowering for scalable vector bfloat operations.
Specifically: fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm, fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt & ftrunc
1 parent 3e3780e commit c638ddd

File tree

8 files changed

+1218
-0
lines changed

8 files changed

+1218
-0
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,12 @@ class SelectionDAG {
15651565
SDValue getSetFPEnv(SDValue Chain, const SDLoc &dl, SDValue Ptr, EVT MemVT,
15661566
MachineMemOperand *MMO);
15671567

1568+
SDValue getExtractSubvector(const SDLoc &DL, EVT VT, SDValue V,
1569+
uint64_t Idx) {
1570+
return getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V,
1571+
getVectorIdxConstant(Idx, DL));
1572+
}
1573+
15681574
/// Construct a node to track a Value* through the backend.
15691575
SDValue getSrcValue(const Value *v);
15701576

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ class SDValue {
191191
return getValueType().getSimpleVT();
192192
}
193193

194+
/// Return the scalar ValueType of the referenced return value.
195+
EVT getScalarValueType() const { return getValueType().getScalarType(); }
196+
194197
/// Returns the size of the value in bits.
195198
///
196199
/// If the value type is a scalable vector type, the scalable property will

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,12 +1663,32 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
16631663
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
16641664
setOperationAction(ISD::BITCAST, VT, Custom);
16651665
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1666+
setOperationAction(ISD::FCEIL, VT, Custom);
1667+
setOperationAction(ISD::FDIV, VT, Custom);
1668+
setOperationAction(ISD::FFLOOR, VT, Custom);
1669+
setOperationAction(ISD::FMA, VT, Custom);
1670+
setOperationAction(ISD::FMAXIMUM, VT, Custom);
1671+
setOperationAction(ISD::FMAXNUM, VT, Custom);
1672+
setOperationAction(ISD::FMINIMUM, VT, Custom);
1673+
setOperationAction(ISD::FMINNUM, VT, Custom);
1674+
setOperationAction(ISD::FNEARBYINT, VT, Custom);
16661675
setOperationAction(ISD::FP_EXTEND, VT, Custom);
16671676
setOperationAction(ISD::FP_ROUND, VT, Custom);
1677+
setOperationAction(ISD::FRINT, VT, Custom);
1678+
setOperationAction(ISD::FROUND, VT, Custom);
1679+
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1680+
setOperationAction(ISD::FSQRT, VT, Custom);
1681+
setOperationAction(ISD::FTRUNC, VT, Custom);
16681682
setOperationAction(ISD::MLOAD, VT, Custom);
16691683
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
16701684
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
16711685
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1686+
1687+
if (!Subtarget->hasSVEB16B16()) {
1688+
setOperationAction(ISD::FADD, VT, Custom);
1689+
setOperationAction(ISD::FMUL, VT, Custom);
1690+
setOperationAction(ISD::FSUB, VT, Custom);
1691+
}
16721692
}
16731693

16741694
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7051,32 +7071,58 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
70517071
case ISD::UMULO:
70527072
return LowerXALUO(Op, DAG);
70537073
case ISD::FADD:
7074+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7075+
return LowerBFloatOp(Op, DAG);
70547076
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
70557077
case ISD::FSUB:
7078+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7079+
return LowerBFloatOp(Op, DAG);
70567080
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
70577081
case ISD::FMUL:
7082+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7083+
return LowerBFloatOp(Op, DAG);
70587084
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
70597085
case ISD::FMA:
7086+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7087+
return LowerBFloatOp(Op, DAG);
70607088
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
70617089
case ISD::FDIV:
7090+
if (Op.getScalarValueType() == MVT::bf16)
7091+
return LowerBFloatOp(Op, DAG);
70627092
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
70637093
case ISD::FNEG:
70647094
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
70657095
case ISD::FCEIL:
7096+
if (Op.getScalarValueType() == MVT::bf16)
7097+
return LowerBFloatOp(Op, DAG);
70667098
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
70677099
case ISD::FFLOOR:
7100+
if (Op.getScalarValueType() == MVT::bf16)
7101+
return LowerBFloatOp(Op, DAG);
70687102
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
70697103
case ISD::FNEARBYINT:
7104+
if (Op.getScalarValueType() == MVT::bf16)
7105+
return LowerBFloatOp(Op, DAG);
70707106
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
70717107
case ISD::FRINT:
7108+
if (Op.getScalarValueType() == MVT::bf16)
7109+
return LowerBFloatOp(Op, DAG);
70727110
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
70737111
case ISD::FROUND:
7112+
if (Op.getScalarValueType() == MVT::bf16)
7113+
return LowerBFloatOp(Op, DAG);
70747114
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
70757115
case ISD::FROUNDEVEN:
7116+
if (Op.getScalarValueType() == MVT::bf16)
7117+
return LowerBFloatOp(Op, DAG);
70767118
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
70777119
case ISD::FTRUNC:
7120+
if (Op.getScalarValueType() == MVT::bf16)
7121+
return LowerBFloatOp(Op, DAG);
70787122
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
70797123
case ISD::FSQRT:
7124+
if (Op.getScalarValueType() == MVT::bf16)
7125+
return LowerBFloatOp(Op, DAG);
70807126
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
70817127
case ISD::FABS:
70827128
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
@@ -7242,12 +7288,20 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
72427288
case ISD::SUB:
72437289
return LowerToScalableOp(Op, DAG);
72447290
case ISD::FMAXIMUM:
7291+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7292+
return LowerBFloatOp(Op, DAG);
72457293
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
72467294
case ISD::FMAXNUM:
7295+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7296+
return LowerBFloatOp(Op, DAG);
72477297
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
72487298
case ISD::FMINIMUM:
7299+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7300+
return LowerBFloatOp(Op, DAG);
72497301
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
72507302
case ISD::FMINNUM:
7303+
if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
7304+
return LowerBFloatOp(Op, DAG);
72517305
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
72527306
case ISD::VSELECT:
72537307
return LowerFixedLengthVectorSelectToSVE(Op, DAG);
@@ -28466,6 +28520,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
2846628520
return convertFromScalableVector(DAG, VT, ScalableRes);
2846728521
}
2846828522

28523+
// Lower bfloat16 operations by upcasting to float32, performing the operation
28524+
// and then downcasting the result back to bfloat16.
28525+
SDValue AArch64TargetLowering::LowerBFloatOp(SDValue Op,
28526+
SelectionDAG &DAG) const {
28527+
SDLoc DL(Op);
28528+
EVT VT = Op.getValueType();
28529+
assert(isTypeLegal(VT) && VT.isScalableVector() && "Unexpected type!");
28530+
28531+
// Split the vector and try again.
28532+
if (VT == MVT::nxv8bf16) {
28533+
SmallVector<SDValue, 4> LoOps, HiOps;
28534+
for (const SDValue &V : Op->op_values()) {
28535+
LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
28536+
HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
28537+
}
28538+
28539+
unsigned Opc = Op.getOpcode();
28540+
SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
28541+
SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
28542+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
28543+
}
28544+
28545+
// Promote to float and try again.
28546+
EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
28547+
28548+
SmallVector<SDValue, 4> Ops;
28549+
for (const SDValue &V : Op->op_values())
28550+
Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
28551+
28552+
SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
28553+
return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
28554+
DAG.getIntPtrConstant(0, DL, true));
28555+
}
28556+
2846928557
// Convert vector operation 'Op' to an equivalent predicated operation whereby
2847028558
// the original operation's type is used to construct a suitable predicate.
2847128559
// NOTE: The results for inactive lanes are undefined.

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,7 @@ class AArch64TargetLowering : public TargetLowering {
12241224
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
12251225
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
12261226
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
1227+
SDValue LowerBFloatOp(SDValue Op, SelectionDAG &DAG) const;
12271228

12281229
SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
12291230

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,13 @@ let Predicates = [HasSVEorSME] in {
663663
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
664664
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
665665

666+
foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
667+
def : Pat<(VT (fabs VT:$op)),
668+
(AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
669+
def : Pat<(VT (fneg VT:$op)),
670+
(EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
671+
}
672+
666673
// zext(cmpeq(x, splat(0))) -> cnot(x)
667674
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
668675
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
22992299
def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;
23002300

23012301
def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
2302+
def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
2303+
def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
23022304
}
23032305

23042306
multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
@@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
90789080
def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;
90799081

90809082
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
9083+
def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
9084+
def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
90819085
}
90829086

90839087
// Predicated pseudo floating point three operand instructions.
@@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
90999103
def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;
91009104

91019105
def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
9106+
def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
9107+
def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
91029108
}
91039109

91049110
// Predicated pseudo integer two operand instructions.

0 commit comments

Comments
 (0)