Skip to content

Commit ffa2810

Browse files
authored
[RISCV] Optimize lowering of VECREDUCE_FMINIMUM/VECREDUCE_FMAXIMUM. (#85165)
Use a normal min/max reduction that doesn't propagate nans and force the result to nan at the end if any elements were nan.
1 parent 8c4546f commit ffa2810

File tree

2 files changed

+326
-1098
lines changed

2 files changed

+326
-1098
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
717717

718718
static const unsigned FloatingPointVecReduceOps[] = {
719719
ISD::VECREDUCE_FADD, ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_FMIN,
720-
ISD::VECREDUCE_FMAX};
720+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMINIMUM, ISD::VECREDUCE_FMAXIMUM};
721721

722722
if (!Subtarget.is64Bit()) {
723723
// We must custom-lower certain vXi64 operations on RV32 due to the vector
@@ -6541,6 +6541,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
65416541
case ISD::VECREDUCE_SEQ_FADD:
65426542
case ISD::VECREDUCE_FMIN:
65436543
case ISD::VECREDUCE_FMAX:
6544+
case ISD::VECREDUCE_FMAXIMUM:
6545+
case ISD::VECREDUCE_FMINIMUM:
65446546
return lowerFPVECREDUCE(Op, DAG);
65456547
case ISD::VP_REDUCE_ADD:
65466548
case ISD::VP_REDUCE_UMAX:
@@ -9541,14 +9543,17 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
95419543
case ISD::VECREDUCE_SEQ_FADD:
95429544
return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1),
95439545
Op.getOperand(0));
9546+
case ISD::VECREDUCE_FMINIMUM:
9547+
case ISD::VECREDUCE_FMAXIMUM:
95449548
case ISD::VECREDUCE_FMIN:
95459549
case ISD::VECREDUCE_FMAX: {
95469550
SDValue Front =
95479551
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
95489552
DAG.getVectorIdxConstant(0, DL));
9549-
unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN)
9550-
? RISCVISD::VECREDUCE_FMIN_VL
9551-
: RISCVISD::VECREDUCE_FMAX_VL;
9553+
unsigned RVVOpc =
9554+
(Opcode == ISD::VECREDUCE_FMIN || Opcode == ISD::VECREDUCE_FMINIMUM)
9555+
? RISCVISD::VECREDUCE_FMIN_VL
9556+
: RISCVISD::VECREDUCE_FMAX_VL;
95529557
return std::make_tuple(RVVOpc, Op.getOperand(0), Front);
95539558
}
95549559
}
@@ -9571,9 +9576,30 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
95719576
VectorVal = convertToScalableVector(ContainerVT, VectorVal, DAG, Subtarget);
95729577
}
95739578

9579+
MVT ResVT = Op.getSimpleValueType();
95749580
auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
9575-
return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), ScalarVal,
9576-
VectorVal, Mask, VL, DL, DAG, Subtarget);
9581+
SDValue Res = lowerReductionSeq(RVVOpcode, ResVT, ScalarVal, VectorVal, Mask,
9582+
VL, DL, DAG, Subtarget);
9583+
if (Op.getOpcode() != ISD::VECREDUCE_FMINIMUM &&
9584+
Op.getOpcode() != ISD::VECREDUCE_FMAXIMUM)
9585+
return Res;
9586+
9587+
if (Op->getFlags().hasNoNaNs())
9588+
return Res;
9589+
9590+
// Force output to NaN if any element is Nan.
9591+
SDValue IsNan =
9592+
DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
9593+
{VectorVal, VectorVal, DAG.getCondCode(ISD::SETNE),
9594+
DAG.getUNDEF(Mask.getValueType()), Mask, VL});
9595+
MVT XLenVT = Subtarget.getXLenVT();
9596+
SDValue CPop = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, IsNan, Mask, VL);
9597+
SDValue NoNaNs = DAG.getSetCC(DL, XLenVT, CPop,
9598+
DAG.getConstant(0, DL, XLenVT), ISD::SETEQ);
9599+
return DAG.getSelect(
9600+
DL, ResVT, NoNaNs, Res,
9601+
DAG.getConstantFP(APFloat::getNaN(DAG.EVTToAPFloatSemantics(ResVT)), DL,
9602+
ResVT));
95779603
}
95789604

95799605
SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,

0 commit comments

Comments
 (0)