Skip to content

[LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. #143540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ class VectorLegalizer {

void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);

/// Calculate the reduction using a type of higher precision and round the
/// result to match the original type. Setting NonArithmetic signifies the
/// rounding of the result does not affect its value.
void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
bool NonArithmetic);

public:
VectorLegalizer(SelectionDAG& dag) :
DAG(dag), TLI(dag.getTargetLoweringInfo()) {}
Expand Down Expand Up @@ -500,20 +506,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMINIMUM:
case ISD::VECREDUCE_FMUL:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
// Defer non-vector results to LegalizeDAG.
if (Action == TargetLowering::Promote)
Action = TargetLowering::Legal;
break;
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
Expand Down Expand Up @@ -688,6 +688,24 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
Results.push_back(Round.getValue(1));
}

// Promote an unordered floating-point vector reduction by widening its input:
// the vector operand is FP_EXTENDed to the promoted type, the same reduction
// opcode is re-issued on the widened vector, and the scalar result is
// FP_ROUNDed back to the node's original result type.
//
//   Node          - the VECREDUCE_* node being promoted; operand 0 is the
//                   vector being reduced.
//   Results       - receives the single replacement value (the rounded
//                   scalar).
//   NonArithmetic - forwarded as the constant second operand of the final
//                   FP_ROUND. Callers pass true for the min/max style
//                   reductions and false for VECREDUCE_FADD; true signifies
//                   that rounding the result does not affect its value.
void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
SmallVectorImpl<SDValue> &Results,
bool NonArithmetic) {
MVT OpVT = Node->getOperand(0).getSimpleValueType();
assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
// Ask the target which type this opcode is promoted to for the operand type.
MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);

SDLoc DL(Node);
// Widen the input vector to the promoted type.
SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the strictfp case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does strictfp apply here? There are no STRICT_VECREDUCE_ nodes. There are the ordered VECREDUCE_SEQ_ nodes, but they go down a different path so are not covered by this PR.

Copy link
Contributor

@arsenm arsenm Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This pushes on the broader issue with the current strictfp strategy where we need to duplicate all possible FP intrinsics and we're missing them here.

So for these non-strict intrinsics it's not an issue, the issue is we don't have strict versions of these

// Re-issue the same reduction opcode on the widened vector. The result type
// is the promoted vector's element type, and the original node's flags are
// carried over unchanged.
SDValue Rdx =
DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
Node->getFlags());
// Round the scalar back to the original result type. NonArithmetic becomes
// the FP_ROUND's second operand, emitted as a target constant.
SDValue Res =
DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
DAG.getIntPtrConstant(NonArithmetic, DL, /*isTarget=*/true));
Results.push_back(Res);
}

void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
// For a few operations there is a specific concept for promotion based on
// the operand's type.
Expand Down Expand Up @@ -719,6 +737,15 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::STRICT_FMA:
PromoteSTRICT(Node, Results);
return;
case ISD::VECREDUCE_FADD:
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/false);
return;
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMINIMUM:
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
return;
case ISD::FP_ROUND:
case ISD::FP_EXTEND:
// These operations are used to do promotion so they can't be promoted
Expand Down
15 changes: 10 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
SDValue Op = Node->getOperand(0);
EVT VT = Op.getValueType();

if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");

// Try to use a shuffle reduction for power of two vectors.
if (VT.isPow2VectorType()) {
while (VT.getVectorNumElements() > 1) {
while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
break;
Expand All @@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
VT = HalfVT;

// Stop if splitting is enough to make the reduction legal.
if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
Node->getFlags());
}
}

if (VT.isScalableVector())
reportFatalInternalError(
"Expanding reductions for scalable vectors is undefined.");

EVT EltVT = VT.getVectorElementType();
unsigned NumElts = VT.getVectorNumElements();

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

for (auto Opcode :
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
Expand Down
Loading
Loading