Skip to content

Commit 68732ce

Browse files
[LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. (#143540)
The omissions are VECREDUCE_SEQ_* and MUL. The former goes down a different code path and the latter is unsupported across all element types.
1 parent 227cd56 commit 68732ce

File tree

4 files changed

+328
-15
lines changed

4 files changed

+328
-15
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ class VectorLegalizer {
189189

190190
void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
191191

192+
/// Calculate the reduction using a type of higher precision and round the
193+
/// result to match the original type. Setting NonArithmetic signifies the
194+
/// rounding of the result does not affect its value.
195+
void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
196+
bool NonArithmetic);
197+
192198
public:
193199
VectorLegalizer(SelectionDAG& dag) :
194200
DAG(dag), TLI(dag.getTargetLoweringInfo()) {}
@@ -500,20 +506,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
500506
case ISD::VECREDUCE_UMAX:
501507
case ISD::VECREDUCE_UMIN:
502508
case ISD::VECREDUCE_FADD:
503-
case ISD::VECREDUCE_FMUL:
504-
case ISD::VECTOR_FIND_LAST_ACTIVE:
505-
Action = TLI.getOperationAction(Node->getOpcode(),
506-
Node->getOperand(0).getValueType());
507-
break;
508509
case ISD::VECREDUCE_FMAX:
509-
case ISD::VECREDUCE_FMIN:
510510
case ISD::VECREDUCE_FMAXIMUM:
511+
case ISD::VECREDUCE_FMIN:
511512
case ISD::VECREDUCE_FMINIMUM:
513+
case ISD::VECREDUCE_FMUL:
514+
case ISD::VECTOR_FIND_LAST_ACTIVE:
512515
Action = TLI.getOperationAction(Node->getOpcode(),
513516
Node->getOperand(0).getValueType());
514-
// Defer non-vector results to LegalizeDAG.
515-
if (Action == TargetLowering::Promote)
516-
Action = TargetLowering::Legal;
517517
break;
518518
case ISD::VECREDUCE_SEQ_FADD:
519519
case ISD::VECREDUCE_SEQ_FMUL:
@@ -688,6 +688,24 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
688688
Results.push_back(Round.getValue(1));
689689
}
690690

691+
void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
692+
SmallVectorImpl<SDValue> &Results,
693+
bool NonArithmetic) {
694+
MVT OpVT = Node->getOperand(0).getSimpleValueType();
695+
assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
696+
MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
697+
698+
SDLoc DL(Node);
699+
SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
700+
SDValue Rdx =
701+
DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
702+
Node->getFlags());
703+
SDValue Res =
704+
DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
705+
DAG.getIntPtrConstant(NonArithmetic, DL, /*isTarget=*/true));
706+
Results.push_back(Res);
707+
}
708+
691709
void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
692710
// For a few operations there is a specific concept for promotion based on
693711
// the operand's type.
@@ -719,6 +737,15 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
719737
case ISD::STRICT_FMA:
720738
PromoteSTRICT(Node, Results);
721739
return;
740+
case ISD::VECREDUCE_FADD:
741+
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/false);
742+
return;
743+
case ISD::VECREDUCE_FMAX:
744+
case ISD::VECREDUCE_FMAXIMUM:
745+
case ISD::VECREDUCE_FMIN:
746+
case ISD::VECREDUCE_FMINIMUM:
747+
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
748+
return;
722749
case ISD::FP_ROUND:
723750
case ISD::FP_EXTEND:
724751
// These operations are used to do promotion so they can't be promoted

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11413,13 +11413,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
1141311413
SDValue Op = Node->getOperand(0);
1141411414
EVT VT = Op.getValueType();
1141511415

11416-
if (VT.isScalableVector())
11417-
report_fatal_error(
11418-
"Expanding reductions for scalable vectors is undefined.");
11419-
1142011416
// Try to use a shuffle reduction for power of two vectors.
1142111417
if (VT.isPow2VectorType()) {
11422-
while (VT.getVectorNumElements() > 1) {
11418+
while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
1142311419
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
1142411420
if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
1142511421
break;
@@ -11428,9 +11424,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
1142811424
std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
1142911425
Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
1143011426
VT = HalfVT;
11427+
11428+
// Stop if splitting is enough to make the reduction legal.
11429+
if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
11430+
return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
11431+
Node->getFlags());
1143111432
}
1143211433
}
1143311434

11435+
if (VT.isScalableVector())
11436+
reportFatalInternalError(
11437+
"Expanding reductions for scalable vectors is undefined.");
11438+
1143411439
EVT EltVT = VT.getVectorElementType();
1143511440
unsigned NumElts = VT.getVectorNumElements();
1143611441

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1762,7 +1762,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
17621762

17631763
for (auto Opcode :
17641764
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
1765-
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
1765+
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
1766+
ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
1767+
ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
17661768
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
17671769
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
17681770
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);

0 commit comments

Comments
 (0)