Skip to content

Commit 3e79847

Browse files
authored
[LegalizeDAG][RISCV] Don't promote f16 vector ISD::FNEG/FABS/FCOPYSIGN to f32 when we don't have Zvfh. (#106652)
The fp_extend will canonicalize NaNs which is not the semantics of FNEG/FABS/FCOPYSIGN. For fixed vectors I'm scalarizing due to test changes on other targets where the scalarization is expected. I will try to address in a follow up. For scalable vectors, we bitcast to integer and use integer logic ops.
1 parent c94bd96 commit 3e79847

File tree

9 files changed

+3691
-1482
lines changed

9 files changed

+3691
-1482
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class VectorLegalizer {
142142
std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
143143
SDValue ExpandStore(SDNode *N);
144144
SDValue ExpandFNEG(SDNode *Node);
145+
SDValue ExpandFABS(SDNode *Node);
146+
SDValue ExpandFCOPYSIGN(SDNode *Node);
145147
void ExpandFSUB(SDNode *Node, SmallVectorImpl<SDValue> &Results);
146148
void ExpandSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
147149
void ExpandBITREVERSE(SDNode *Node, SmallVectorImpl<SDValue> &Results);
@@ -942,6 +944,18 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
942944
return;
943945
}
944946
break;
947+
case ISD::FABS:
948+
if (SDValue Expanded = ExpandFABS(Node)) {
949+
Results.push_back(Expanded);
950+
return;
951+
}
952+
break;
953+
case ISD::FCOPYSIGN:
954+
if (SDValue Expanded = ExpandFCOPYSIGN(Node)) {
955+
Results.push_back(Expanded);
956+
return;
957+
}
958+
break;
945959
case ISD::FSUB:
946960
ExpandFSUB(Node, Results);
947961
return;
@@ -1781,7 +1795,7 @@ SDValue VectorLegalizer::ExpandFNEG(SDNode *Node) {
17811795

17821796
// FIXME: The FSUB check is here to force unrolling v1f64 vectors on AArch64.
17831797
if (!TLI.isOperationLegalOrCustom(ISD::XOR, IntVT) ||
1784-
!TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
1798+
!(TLI.isOperationLegalOrCustom(ISD::FSUB, VT) || VT.isScalableVector()))
17851799
return SDValue();
17861800

17871801
SDLoc DL(Node);
@@ -1792,6 +1806,53 @@ SDValue VectorLegalizer::ExpandFNEG(SDNode *Node) {
17921806
return DAG.getNode(ISD::BITCAST, DL, VT, Xor);
17931807
}
17941808

1809+
SDValue VectorLegalizer::ExpandFABS(SDNode *Node) {
1810+
EVT VT = Node->getValueType(0);
1811+
EVT IntVT = VT.changeVectorElementTypeToInteger();
1812+
1813+
// FIXME: We shouldn't restrict this to scalable vectors.
1814+
if (!TLI.isOperationLegalOrCustom(ISD::AND, IntVT) || !VT.isScalableVector())
1815+
return SDValue();
1816+
1817+
SDLoc DL(Node);
1818+
SDValue Cast = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(0));
1819+
SDValue ClearSignMask = DAG.getConstant(
1820+
APInt::getSignedMaxValue(IntVT.getScalarSizeInBits()), DL, IntVT);
1821+
SDValue ClearedSign = DAG.getNode(ISD::AND, DL, IntVT, Cast, ClearSignMask);
1822+
return DAG.getNode(ISD::BITCAST, DL, VT, ClearedSign);
1823+
}
1824+
1825+
SDValue VectorLegalizer::ExpandFCOPYSIGN(SDNode *Node) {
1826+
EVT VT = Node->getValueType(0);
1827+
EVT IntVT = VT.changeVectorElementTypeToInteger();
1828+
1829+
// FIXME: We shouldn't restrict this to scalable vectors.
1830+
if (VT != Node->getOperand(1).getValueType() ||
1831+
!TLI.isOperationLegalOrCustom(ISD::AND, IntVT) ||
1832+
!TLI.isOperationLegalOrCustom(ISD::OR, IntVT) || !VT.isScalableVector())
1833+
return SDValue();
1834+
1835+
SDLoc DL(Node);
1836+
SDValue Mag = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(0));
1837+
SDValue Sign = DAG.getNode(ISD::BITCAST, DL, IntVT, Node->getOperand(1));
1838+
1839+
SDValue SignMask = DAG.getConstant(
1840+
APInt::getSignMask(IntVT.getScalarSizeInBits()), DL, IntVT);
1841+
SDValue SignBit = DAG.getNode(ISD::AND, DL, IntVT, Sign, SignMask);
1842+
1843+
SDValue ClearSignMask = DAG.getConstant(
1844+
APInt::getSignedMaxValue(IntVT.getScalarSizeInBits()), DL, IntVT);
1845+
SDValue ClearedSign = DAG.getNode(ISD::AND, DL, IntVT, Mag, ClearSignMask);
1846+
1847+
SDNodeFlags Flags;
1848+
Flags.setDisjoint(true);
1849+
1850+
SDValue CopiedSign =
1851+
DAG.getNode(ISD::OR, DL, IntVT, ClearedSign, SignBit, Flags);
1852+
1853+
return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
1854+
}
1855+
17951856
void VectorLegalizer::ExpandFSUB(SDNode *Node,
17961857
SmallVectorImpl<SDValue> &Results) {
17971858
// For floating-point values, (a-b) is the same as a+(-b). If FNEG is legal,

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -934,13 +934,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
934934

935935
// TODO: support more ops.
936936
static const unsigned ZvfhminPromoteOps[] = {
937-
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
938-
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
939-
ISD::FABS, ISD::FNEG, ISD::FCOPYSIGN, ISD::FCEIL,
940-
ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN, ISD::FRINT,
941-
ISD::FNEARBYINT, ISD::IS_FPCLASS, ISD::SETCC, ISD::FMAXIMUM,
942-
ISD::FMINIMUM, ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
943-
ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA};
937+
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
938+
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
939+
ISD::FCEIL, ISD::FFLOOR, ISD::FROUND, ISD::FROUNDEVEN,
940+
ISD::FRINT, ISD::FNEARBYINT, ISD::IS_FPCLASS, ISD::SETCC,
941+
ISD::FMAXIMUM, ISD::FMINIMUM, ISD::STRICT_FADD, ISD::STRICT_FSUB,
942+
ISD::STRICT_FMUL, ISD::STRICT_FDIV, ISD::STRICT_FSQRT, ISD::STRICT_FMA};
944943

945944
// TODO: support more vp ops.
946945
static const unsigned ZvfhminPromoteVPOps[] = {ISD::VP_FADD,
@@ -1082,6 +1081,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
10821081
// load/store
10831082
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
10841083

1084+
setOperationAction(ISD::FNEG, VT, Expand);
1085+
setOperationAction(ISD::FABS, VT, Expand);
1086+
setOperationAction(ISD::FCOPYSIGN, VT, Expand);
1087+
10851088
// Custom split nxv32f16 since nxv32f32 if not legal.
10861089
if (VT == MVT::nxv32f16) {
10871090
setOperationAction(ZvfhminPromoteOps, VT, Custom);
@@ -1337,6 +1340,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
13371340
// available.
13381341
setOperationAction(ISD::BUILD_VECTOR, MVT::f16, Custom);
13391342
}
1343+
setOperationAction(ISD::FNEG, VT, Expand);
1344+
setOperationAction(ISD::FABS, VT, Expand);
1345+
setOperationAction(ISD::FCOPYSIGN, VT, Expand);
13401346
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
13411347
// Don't promote f16 vector operations to f32 if f32 vector type is
13421348
// not legal.

0 commit comments

Comments
 (0)