Skip to content

Commit 02dd6b1

Browse files
[LLVM][CodeGen] Add lowering for scalable vector bfloat operations. (#109803)
Specifically: fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm, fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt & ftrunc
1 parent 8b6e1dc commit 02dd6b1

File tree

9 files changed

+1234
-31
lines changed

9 files changed

+1234
-31
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5616,6 +5616,10 @@ class TargetLowering : public TargetLoweringBase {
56165616
return true;
56175617
}
56185618

5619+
// Expand vector operation by dividing it into smaller length operations and
5620+
// joining their results. SDValue() is returned when expansion did not happen.
5621+
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
5622+
56195623
private:
56205624
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
56215625
const SDLoc &DL, DAGCombinerInfo &DCI) const;

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,24 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
11971197
case ISD::UCMP:
11981198
Results.push_back(TLI.expandCMP(Node, DAG));
11991199
return;
1200+
1201+
case ISD::FADD:
1202+
case ISD::FMUL:
1203+
case ISD::FMA:
1204+
case ISD::FDIV:
1205+
case ISD::FCEIL:
1206+
case ISD::FFLOOR:
1207+
case ISD::FNEARBYINT:
1208+
case ISD::FRINT:
1209+
case ISD::FROUND:
1210+
case ISD::FROUNDEVEN:
1211+
case ISD::FTRUNC:
1212+
case ISD::FSQRT:
1213+
if (SDValue Expanded = TLI.expandVectorNaryOpBySplitting(Node, DAG)) {
1214+
Results.push_back(Expanded);
1215+
return;
1216+
}
1217+
break;
12001218
}
12011219

12021220
SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1885,6 +1903,11 @@ void VectorLegalizer::ExpandFSUB(SDNode *Node,
18851903
TLI.isOperationLegalOrCustom(ISD::FADD, VT))
18861904
return; // Defer to LegalizeDAG
18871905

1906+
if (SDValue Expanded = TLI.expandVectorNaryOpBySplitting(Node, DAG)) {
1907+
Results.push_back(Expanded);
1908+
return;
1909+
}
1910+
18881911
SDValue Tmp = DAG.UnrollVectorOp(Node);
18891912
Results.push_back(Tmp);
18901913
}

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8440,15 +8440,18 @@ TargetLowering::createSelectForFMINNUM_FMAXNUM(SDNode *Node,
84408440

84418441
SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,
84428442
SelectionDAG &DAG) const {
8443-
SDLoc dl(Node);
8444-
unsigned NewOp = Node->getOpcode() == ISD::FMINNUM ?
8445-
ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
8446-
EVT VT = Node->getValueType(0);
8443+
if (SDValue Expanded = expandVectorNaryOpBySplitting(Node, DAG))
8444+
return Expanded;
84478445

8446+
EVT VT = Node->getValueType(0);
84488447
if (VT.isScalableVector())
84498448
report_fatal_error(
84508449
"Expanding fminnum/fmaxnum for scalable vectors is undefined.");
84518450

8451+
SDLoc dl(Node);
8452+
unsigned NewOp =
8453+
Node->getOpcode() == ISD::FMINNUM ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
8454+
84528455
if (isOperationLegalOrCustom(NewOp, VT)) {
84538456
SDValue Quiet0 = Node->getOperand(0);
84548457
SDValue Quiet1 = Node->getOperand(1);
@@ -8493,6 +8496,9 @@ SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,
84938496

84948497
SDValue TargetLowering::expandFMINIMUM_FMAXIMUM(SDNode *N,
84958498
SelectionDAG &DAG) const {
8499+
if (SDValue Expanded = expandVectorNaryOpBySplitting(N, DAG))
8500+
return Expanded;
8501+
84968502
SDLoc DL(N);
84978503
SDValue LHS = N->getOperand(0);
84988504
SDValue RHS = N->getOperand(1);
@@ -11920,3 +11926,35 @@ bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
1192011926
}
1192111927
return false;
1192211928
}
11929+
11930+
/// Expand a vector operation by dividing it into two half-length operations
/// and joining their results with CONCAT_VECTORS.
///
/// Every operand of \p Node is split using the half types derived from the
/// node's first result type, so this presumably expects nodes whose operands
/// all share that vector type (verify before adding new callers).
///
/// \param Node The vector node to expand; only result 0 is considered.
/// \param DAG  The DAG in which the replacement nodes are created.
/// \returns The concatenation of the low and high half operations, or an
///          empty SDValue() when no expansion happened: the result is not a
///          vector, its element count is not a known multiple of two, the two
///          halves differ in type, the half type is not legal, or the
///          operation is not legal/custom/promote for the half type.
SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
                                                      SelectionDAG &DAG) const {
  EVT VT = Node->getValueType(0);
  // Despite its documentation, GetSplitDestVTs will assert if VT cannot be
  // split into two equal parts.
  if (!VT.isVector() || !VT.getVectorElementCount().isKnownMultipleOf(2))
    return SDValue();

  // Restrict expansion to cases where both parts can be concatenated.
  auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);
  if (LoVT != HiVT || !isTypeLegal(LoVT))
    return SDValue();

  unsigned Opcode = Node->getOpcode();

  // Don't expand if the result is likely to be unrolled anyway.
  if (!isOperationLegalOrCustomOrPromote(Opcode, LoVT))
    return SDValue();

  SDLoc DL(Node);
  SmallVector<SDValue, 4> LoOps, HiOps;
  for (const SDValue &V : Node->op_values()) {
    auto [Lo, Hi] = DAG.SplitVector(V, DL, LoVT, HiVT);
    LoOps.push_back(Lo);
    HiOps.push_back(Hi);
  }

  // Pass the original node's flags explicitly so both halves keep them
  // (e.g. fast-math flags) instead of relying on getNode's defaults.
  const SDNodeFlags Flags = Node->getFlags();
  SDValue SplitOpLo = DAG.getNode(Opcode, DL, LoVT, LoOps, Flags);
  SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps, Flags);
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,12 +1663,42 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
16631663
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
16641664
setOperationAction(ISD::BITCAST, VT, Custom);
16651665
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1666+
setOperationAction(ISD::FABS, VT, Legal);
1667+
setOperationAction(ISD::FNEG, VT, Legal);
16661668
setOperationAction(ISD::FP_EXTEND, VT, Custom);
16671669
setOperationAction(ISD::FP_ROUND, VT, Custom);
16681670
setOperationAction(ISD::MLOAD, VT, Custom);
16691671
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
16701672
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
16711673
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1674+
1675+
if (Subtarget->hasSVEB16B16()) {
1676+
setOperationAction(ISD::FADD, VT, Legal);
1677+
setOperationAction(ISD::FMA, VT, Custom);
1678+
setOperationAction(ISD::FMAXIMUM, VT, Custom);
1679+
setOperationAction(ISD::FMAXNUM, VT, Custom);
1680+
setOperationAction(ISD::FMINIMUM, VT, Custom);
1681+
setOperationAction(ISD::FMINNUM, VT, Custom);
1682+
setOperationAction(ISD::FMUL, VT, Legal);
1683+
setOperationAction(ISD::FSUB, VT, Legal);
1684+
}
1685+
}
1686+
1687+
for (auto Opcode :
1688+
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
1689+
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC}) {
1690+
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
1691+
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
1692+
setOperationAction(Opcode, MVT::nxv8bf16, Expand);
1693+
}
1694+
1695+
if (!Subtarget->hasSVEB16B16()) {
1696+
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
1697+
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
1698+
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
1699+
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
1700+
setOperationAction(Opcode, MVT::nxv8bf16, Expand);
1701+
}
16721702
}
16731703

16741704
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,15 @@ let Predicates = [HasSVEorSME] in {
663663
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
664664
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
665665

666+
foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
667+
// No dedicated instruction, so just clear the sign bit.
668+
def : Pat<(VT (fabs VT:$op)),
669+
(AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
670+
// No dedicated instruction, so just invert the sign bit.
671+
def : Pat<(VT (fneg VT:$op)),
672+
(EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
673+
}
674+
666675
// zext(cmpeq(x, splat(0))) -> cnot(x)
667676
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
668677
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
22992299
def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;
23002300

23012301
def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
2302+
def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
2303+
def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
23022304
}
23032305

23042306
multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
@@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
90789080
def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;
90799081

90809082
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
9083+
def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
9084+
def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
90819085
}
90829086

90839087
// Predicated pseudo floating point three operand instructions.
@@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
90999103
def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;
91009104

91019105
def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
9106+
def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
9107+
def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
91029108
}
91039109

91049110
// Predicated pseudo integer two operand instructions.

0 commit comments

Comments
 (0)