Skip to content

[LLVM][CodeGen] Add lowering for scalable vector bfloat operations. #109803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5616,6 +5616,10 @@ class TargetLowering : public TargetLoweringBase {
return true;
}

// Expand vector operation by dividing it into smaller length operations and
// joining their results. SDValue() is returned when expansion did not happen.
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;

private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
Expand Down
23 changes: 23 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,24 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::UCMP:
Results.push_back(TLI.expandCMP(Node, DAG));
return;

case ISD::FADD:
case ISD::FMUL:
case ISD::FMA:
case ISD::FDIV:
case ISD::FCEIL:
case ISD::FFLOOR:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FTRUNC:
case ISD::FSQRT:
if (SDValue Expanded = TLI.expandVectorNaryOpBySplitting(Node, DAG)) {
Results.push_back(Expanded);
return;
}
break;
}

SDValue Unrolled = DAG.UnrollVectorOp(Node);
Expand Down Expand Up @@ -1874,6 +1892,11 @@ void VectorLegalizer::ExpandFSUB(SDNode *Node,
TLI.isOperationLegalOrCustom(ISD::FADD, VT))
return; // Defer to LegalizeDAG

if (SDValue Expanded = TLI.expandVectorNaryOpBySplitting(Node, DAG)) {
Results.push_back(Expanded);
return;
}

SDValue Tmp = DAG.UnrollVectorOp(Node);
Results.push_back(Tmp);
}
Expand Down
46 changes: 42 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8440,15 +8440,18 @@ TargetLowering::createSelectForFMINNUM_FMAXNUM(SDNode *Node,

SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,
SelectionDAG &DAG) const {
SDLoc dl(Node);
unsigned NewOp = Node->getOpcode() == ISD::FMINNUM ?
ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
EVT VT = Node->getValueType(0);
if (SDValue Expanded = expandVectorNaryOpBySplitting(Node, DAG))
return Expanded;

EVT VT = Node->getValueType(0);
if (VT.isScalableVector())
report_fatal_error(
"Expanding fminnum/fmaxnum for scalable vectors is undefined.");

SDLoc dl(Node);
unsigned NewOp =
Node->getOpcode() == ISD::FMINNUM ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;

if (isOperationLegalOrCustom(NewOp, VT)) {
SDValue Quiet0 = Node->getOperand(0);
SDValue Quiet1 = Node->getOperand(1);
Expand Down Expand Up @@ -8493,6 +8496,9 @@ SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,

SDValue TargetLowering::expandFMINIMUM_FMAXIMUM(SDNode *N,
SelectionDAG &DAG) const {
if (SDValue Expanded = expandVectorNaryOpBySplitting(N, DAG))
return Expanded;

SDLoc DL(N);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
Expand Down Expand Up @@ -11917,3 +11923,35 @@ bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
}
return false;
}

SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
SelectionDAG &DAG) const {
EVT VT = Node->getValueType(0);
// Despite its documentation, GetSplitDestVTs will assert if VT cannot be
// split into two equal parts.
if (!VT.isVector() || !VT.getVectorElementCount().isKnownMultipleOf(2))
return SDValue();

// Restrict expansion to cases where both parts can be concatenated.
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);
if (LoVT != HiVT || !isTypeLegal(LoVT))
return SDValue();

SDLoc DL(Node);
unsigned Opcode = Node->getOpcode();

// Don't expand if the result is likely to be unrolled anyway.
if (!isOperationLegalOrCustomOrPromote(Opcode, LoVT))
return SDValue();

SmallVector<SDValue, 4> LoOps, HiOps;
for (const SDValue &V : Node->op_values()) {
auto [Lo, Hi] = DAG.SplitVector(V, DL, LoVT, HiVT);
LoOps.push_back(Lo);
HiOps.push_back(Hi);
}

SDValue SplitOpLo = DAG.getNode(Opcode, DL, LoVT, LoOps);
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
}
30 changes: 30 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1663,12 +1663,42 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::FABS, VT, Legal);
setOperationAction(ISD::FNEG, VT, Legal);
setOperationAction(ISD::FP_EXTEND, VT, Custom);
setOperationAction(ISD::FP_ROUND, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);

if (Subtarget->hasSVEB16B16()) {
setOperationAction(ISD::FADD, VT, Legal);
setOperationAction(ISD::FMA, VT, Custom);
setOperationAction(ISD::FMAXIMUM, VT, Custom);
setOperationAction(ISD::FMAXNUM, VT, Custom);
setOperationAction(ISD::FMINIMUM, VT, Custom);
setOperationAction(ISD::FMINNUM, VT, Custom);
setOperationAction(ISD::FMUL, VT, Legal);
setOperationAction(ISD::FSUB, VT, Legal);
}
}

for (auto Opcode :
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationAction(Opcode, MVT::nxv8bf16, Expand);
}

if (!Subtarget->hasSVEB16B16()) {
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationAction(Opcode, MVT::nxv8bf16, Expand);
}
}

setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,15 @@ let Predicates = [HasSVEorSME] in {
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;

foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: May be worth a comment clarifying that you're zeroing or inverting the sign bit.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// No dedicated instruction, so just clear the sign bit.
def : Pat<(VT (fabs VT:$op)),
(AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
// No dedicated instruction, so just invert the sign bit.
def : Pat<(VT (fneg VT:$op)),
(EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
}

// zext(cmpeq(x, splat(0))) -> cnot(x)
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;

def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
}

multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
Expand Down Expand Up @@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;

def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
}

// Predicated pseudo floating point three operand instructions.
Expand All @@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;

def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
}

// Predicated pseudo integer two operand instructions.
Expand Down
Loading
Loading