Skip to content

Commit f2a05c6

Browse files
committed
[RISCV] Add RISCVISD nodes for VFWMADD_VL.
Use them to replace the isel patterns that matched FP_EXTEND_VL+VFMADD_VL with a DAG combine. This makes widening FMA similar to how other widening operations are handled. I plan to use this to make it easier to form tail-undisturbed vfwmacc.
1 parent 6e6bed5 commit f2a05c6

File tree

3 files changed

+89
-21
lines changed

3 files changed

+89
-21
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11196,7 +11196,7 @@ static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc) {
1119611196
return Opcode;
1119711197
}
1119811198

11199-
static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
11199+
static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
1120011200
// Fold FNEG_VL into FMA opcodes.
1120111201
// The first operand of strict-fp is chain.
1120211202
unsigned Offset = N->isTargetStrictFPOpcode();
@@ -11233,6 +11233,59 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
1123311233
VL);
1123411234
}
1123511235

11236+
/// DAG combine entry point for the RISCVISD VF[N]M{ADD,SUB}_VL nodes.
///
/// First tries to fold FNEG_VL operands into the FMA opcode via
/// combineVFMADD_VLWithVFNEG_VL. If that does not fire and \p N is not a
/// strict-FP node, tries to form a widening FMA: when both multiplicands
/// (operands 0 and 1) are RISCVISD::FP_EXTEND_VL nodes using the same mask and
/// VL as \p N, the extends are stripped and the corresponding
/// RISCVISD::VFW*_VL node is built from their narrow sources, keeping the
/// addend (operand 2), mask and VL of \p N.
///
/// \returns the replacement value, or an empty SDValue if nothing matched.
static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) {
  if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
    return V;

  // FIXME: Ignore strict opcodes for now.
  if (N->isTargetStrictFPOpcode())
    return SDValue();

  // Try to form widening FMA.
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  SDValue Mask = N->getOperand(3);
  SDValue VL = N->getOperand(4);

  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
    return SDValue();

  // TODO: Refactor to handle more complex cases similar to
  // combineBinOp_VLToVWBinOp_VL.
  // The extends must have no users other than N, otherwise they would stay
  // alive next to the widening op. Two shapes satisfy that: each extend has a
  // single use, or one extend feeds both multiplicands (x * x + acc) and those
  // two uses are its only uses.
  if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
      (Op0 != Op1 || !Op0->hasNUsesOfValue(2, Op0.getResNo())))
    return SDValue();

  // Check the mask and VL are the same.
  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
    return SDValue();

  // Map the FMA flavour onto its widening counterpart.
  unsigned NewOpc;
  switch (N->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case RISCVISD::VFMADD_VL:
    NewOpc = RISCVISD::VFWMADD_VL;
    break;
  case RISCVISD::VFNMSUB_VL:
    NewOpc = RISCVISD::VFWNMSUB_VL;
    break;
  case RISCVISD::VFNMADD_VL:
    NewOpc = RISCVISD::VFWNMADD_VL;
    break;
  case RISCVISD::VFMSUB_VL:
    NewOpc = RISCVISD::VFWMSUB_VL;
    break;
  }

  // Look through the extends to their narrow sources.
  Op0 = Op0.getOperand(0);
  Op1 = Op1.getOperand(0);

  return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
                     N->getOperand(2), Mask, VL);
}
11288+
1123611289
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
1123711290
const RISCVSubtarget &Subtarget) {
1123811291
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -15074,6 +15127,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1507415127
NODE_NAME_CASE(VFNMADD_VL)
1507515128
NODE_NAME_CASE(VFMSUB_VL)
1507615129
NODE_NAME_CASE(VFNMSUB_VL)
15130+
NODE_NAME_CASE(VFWMADD_VL)
15131+
NODE_NAME_CASE(VFWNMADD_VL)
15132+
NODE_NAME_CASE(VFWMSUB_VL)
15133+
NODE_NAME_CASE(VFWNMSUB_VL)
1507715134
NODE_NAME_CASE(FCOPYSIGN_VL)
1507815135
NODE_NAME_CASE(SMIN_VL)
1507915136
NODE_NAME_CASE(SMAX_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,13 @@ enum NodeType : unsigned {
256256
VFMSUB_VL,
257257
VFNMSUB_VL,
258258

259+
// Vector widening FMA ops with a mask as a fourth operand and VL as a fifth
260+
// operand.
261+
VFWMADD_VL,
262+
VFWNMADD_VL,
263+
VFWMSUB_VL,
264+
VFWNMSUB_VL,
265+
259266
// Widening instructions with a merge value as a third operand, a mask as a
260267
// fourth operand, and VL as a fifth operand.
261268
VWMUL_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,25 @@ def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
136136
SDTCVecEltisVT<4, i1>,
137137
SDTCisSameNumEltsAs<0, 4>,
138138
SDTCisVT<5, XLenVT>]>;
139-
def riscv_vfmadd_vl : SDNode<"RISCVISD::VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
139+
def riscv_vfmadd_vl : SDNode<"RISCVISD::VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
140140
def riscv_vfnmadd_vl : SDNode<"RISCVISD::VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
141-
def riscv_vfmsub_vl : SDNode<"RISCVISD::VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
141+
def riscv_vfmsub_vl : SDNode<"RISCVISD::VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
142142
def riscv_vfnmsub_vl : SDNode<"RISCVISD::VFNMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative]>;
143143

144+
// Type profile shared by the widening FMA nodes (one result, five operands):
//   0: result, a floating-point vector
//   1: floating-point vector with a smaller element type than the result and
//      the same number of elements
//   2: same type as operand 1
//   3: same type as the result
//   4: i1 vector with the same number of elements as the result (the mask)
//   5: XLenVT scalar (the VL)
def SDT_RISCVWVecFMA_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
145+
SDTCisVec<1>, SDTCisFP<1>,
146+
SDTCisOpSmallerThanOp<1, 0>,
147+
SDTCisSameNumEltsAs<0, 1>,
148+
SDTCisSameAs<1, 2>,
149+
SDTCisSameAs<0, 3>,
150+
SDTCVecEltisVT<4, i1>,
151+
SDTCisSameNumEltsAs<0, 4>,
152+
SDTCisVT<5, XLenVT>]>;
153+
// SDNodes for the RISCVISD widening FMA opcodes. All four use the
// SDT_RISCVWVecFMA_VL profile and are marked SDNPCommutative.
def riscv_vfwmadd_vl : SDNode<"RISCVISD::VFWMADD_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
154+
def riscv_vfwnmadd_vl : SDNode<"RISCVISD::VFWNMADD_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
155+
def riscv_vfwmsub_vl : SDNode<"RISCVISD::VFWMSUB_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
156+
def riscv_vfwnmsub_vl : SDNode<"RISCVISD::VFWNMSUB_VL", SDT_RISCVWVecFMA_VL, [SDNPCommutative]>;
157+
144158
def riscv_strict_vfmadd_vl : SDNode<"RISCVISD::STRICT_VFMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
145159
def riscv_strict_vfnmadd_vl : SDNode<"RISCVISD::STRICT_VFNMADD_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
146160
def riscv_strict_vfmsub_vl : SDNode<"RISCVISD::STRICT_VFMSUB_VL", SDT_RISCVVecFMA_VL, [SDNPCommutative, SDNPHasChain]>;
@@ -1514,25 +1528,15 @@ multiclass VPatWidenFPMulAccVL_VV_VF<SDNode vop, string instruction_name> {
15141528
foreach vtiToWti = AllWidenableFloatVectors in {
15151529
defvar vti = vtiToWti.Vti;
15161530
defvar wti = vtiToWti.Wti;
1517-
def : Pat<(vop
1518-
(wti.Vector (riscv_fpextend_vl_oneuse
1519-
(vti.Vector vti.RegClass:$rs1),
1520-
(vti.Mask true_mask), VLOpFrag)),
1521-
(wti.Vector (riscv_fpextend_vl_oneuse
1522-
(vti.Vector vti.RegClass:$rs2),
1523-
(vti.Mask true_mask), VLOpFrag)),
1531+
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
1532+
(vti.Vector vti.RegClass:$rs2),
15241533
(wti.Vector wti.RegClass:$rd), (vti.Mask true_mask),
15251534
VLOpFrag),
15261535
(!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
15271536
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
15281537
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
1529-
def : Pat<(vop
1530-
(wti.Vector (riscv_fpextend_vl_oneuse
1531-
(vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
1532-
(vti.Mask true_mask), VLOpFrag)),
1533-
(wti.Vector (riscv_fpextend_vl_oneuse
1534-
(vti.Vector vti.RegClass:$rs2),
1535-
(vti.Mask true_mask), VLOpFrag)),
1538+
def : Pat<(vop (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
1539+
(vti.Vector vti.RegClass:$rs2),
15361540
(wti.Vector wti.RegClass:$rd), (vti.Mask true_mask),
15371541
VLOpFrag),
15381542
(!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX)
@@ -1827,10 +1831,10 @@ defm : VPatFPMulAccVL_VV_VF<riscv_vfnmadd_vl_oneuse, "PseudoVFNMACC">;
18271831
defm : VPatFPMulAccVL_VV_VF<riscv_vfnmsub_vl_oneuse, "PseudoVFNMSAC">;
18281832

18291833
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
1830-
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfmadd_vl, "PseudoVFWMACC">;
1831-
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfnmadd_vl, "PseudoVFWNMACC">;
1832-
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfmsub_vl, "PseudoVFWMSAC">;
1833-
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfnmsub_vl, "PseudoVFWNMSAC">;
1834+
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwmadd_vl, "PseudoVFWMACC">;
1835+
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
1836+
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
1837+
defm : VPatWidenFPMulAccVL_VV_VF<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
18341838

18351839
// 13.11. Vector Floating-Point MIN/MAX Instructions
18361840
defm : VPatBinaryFPVL_VV_VF<riscv_fminnum_vl, "PseudoVFMIN">;

0 commit comments

Comments
 (0)