Skip to content

Commit e64f5d6

Browse files
authored
[RISCV] Replace RISCVISD::VP_MERGE_VL with a new node that has a separate passthru operand. (llvm#75682)
ISD::VP_MERGE treats the false operand as the source for elements past VL. The vmerge instruction encodes 3 registers and treats the vd register as the source for the tail. This patch adds a new target-specific opcode, RISCVISD::VMERGE_VL, that models the tail source explicitly as a separate passthru operand. During lowering of ISD::VP_MERGE we copy the false operand into this passthru operand. I think we can merge RISCVISD::VSELECT_VL with this new opcode by using an UNDEF passthru, but I'll save that for another patch.
1 parent 3ca9bcc commit e64f5d6

File tree

3 files changed

+91
-67
lines changed

3 files changed

+91
-67
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5530,7 +5530,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
55305530
case ISD::VP_SELECT:
55315531
return RISCVISD::VSELECT_VL;
55325532
case ISD::VP_MERGE:
5533-
return RISCVISD::VP_MERGE_VL;
5533+
return RISCVISD::VMERGE_VL;
55345534
case ISD::VP_ASHR:
55355535
return RISCVISD::SRA_VL;
55365536
case ISD::VP_LSHR:
@@ -5578,6 +5578,8 @@ static bool hasMergeOp(unsigned Opcode) {
55785578
return true;
55795579
if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
55805580
return true;
5581+
if (Opcode == RISCVISD::VMERGE_VL)
5582+
return true;
55815583
return false;
55825584
}
55835585

@@ -8242,8 +8244,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
82428244
AVL);
82438245
// TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
82448246
// It's fine because vmerge does not care mask policy.
8245-
return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff,
8246-
AVL);
8247+
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec, MaskedOff,
8248+
MaskedOff, AVL);
82478249
}
82488250
}
82498251

@@ -10316,9 +10318,20 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
1031610318
for (const auto &OpIdx : enumerate(Op->ops())) {
1031710319
SDValue V = OpIdx.value();
1031810320
assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
10319-
// Add dummy merge value before the mask.
10320-
if (HasMergeOp && *ISD::getVPMaskIdx(Op.getOpcode()) == OpIdx.index())
10321-
Ops.push_back(DAG.getUNDEF(ContainerVT));
10321+
// Add dummy merge value before the mask. Or if there isn't a mask, before
10322+
// EVL.
10323+
if (HasMergeOp) {
10324+
auto MaskIdx = ISD::getVPMaskIdx(Op.getOpcode());
10325+
if (MaskIdx) {
10326+
if (*MaskIdx == OpIdx.index())
10327+
Ops.push_back(DAG.getUNDEF(ContainerVT));
10328+
} else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) ==
10329+
OpIdx.index()) {
10330+
// For VP_MERGE, copy the false operand instead of an undef value.
10331+
assert(Op.getOpcode() == ISD::VP_MERGE);
10332+
Ops.push_back(Ops.back());
10333+
}
10334+
}
1032210335
// Pass through operands which aren't fixed-length vectors.
1032310336
if (!V.getValueType().isFixedLengthVector()) {
1032410337
Ops.push_back(V);
@@ -18658,7 +18671,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1865818671
NODE_NAME_CASE(VNSRL_VL)
1865918672
NODE_NAME_CASE(SETCC_VL)
1866018673
NODE_NAME_CASE(VSELECT_VL)
18661-
NODE_NAME_CASE(VP_MERGE_VL)
18674+
NODE_NAME_CASE(VMERGE_VL)
1866218675
NODE_NAME_CASE(VMAND_VL)
1866318676
NODE_NAME_CASE(VMOR_VL)
1866418677
NODE_NAME_CASE(VMXOR_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,8 @@ enum NodeType : unsigned {
332332

333333
// Vector select with an additional VL operand. This operation is unmasked.
334334
VSELECT_VL,
335-
// Vector select with operand #2 (the value when the condition is false) tied
336-
// to the destination and an additional VL operand. This operation is
337-
// unmasked.
338-
VP_MERGE_VL,
335+
// General vmerge node with mask, true, false, passthru, and vl operands.
336+
VMERGE_VL,
339337

340338
// Mask binary operators.
341339
VMAND_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,14 @@ def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
344344
]>;
345345

346346
def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
347-
def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;
347+
348+
def SDT_RISCVVMERGE_VL : SDTypeProfile<1, 5, [
349+
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
350+
SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameAs<0, 4>,
351+
SDTCisVT<5, XLenVT>
352+
]>;
353+
354+
def riscv_vmerge_vl : SDNode<"RISCVISD::VMERGE_VL", SDT_RISCVVMERGE_VL>;
348355

349356
def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
350357
SDTCisVT<1, XLenVT>]>;
@@ -675,14 +682,14 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
675682
op2_reg_class:$rs2,
676683
GPR:$vl, sew, TAIL_AGNOSTIC)>;
677684
// Tail undisturbed
678-
def : Pat<(riscv_vp_merge_vl true_mask,
685+
def : Pat<(riscv_vmerge_vl true_mask,
679686
(result_type (vop
680687
result_reg_class:$rs1,
681688
(op2_type op2_reg_class:$rs2),
682689
srcvalue,
683690
true_mask,
684691
VLOpFrag)),
685-
result_reg_class:$rs1, VLOpFrag),
692+
result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
686693
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
687694
result_reg_class:$rs1,
688695
op2_reg_class:$rs2,
@@ -712,14 +719,14 @@ multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
712719
FRM_DYN,
713720
GPR:$vl, sew, TAIL_AGNOSTIC)>;
714721
// Tail undisturbed
715-
def : Pat<(riscv_vp_merge_vl true_mask,
722+
def : Pat<(riscv_vmerge_vl true_mask,
716723
(result_type (vop
717724
result_reg_class:$rs1,
718725
(op2_type op2_reg_class:$rs2),
719726
srcvalue,
720727
true_mask,
721728
VLOpFrag)),
722-
result_reg_class:$rs1, VLOpFrag),
729+
result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
723730
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
724731
result_reg_class:$rs1,
725732
op2_reg_class:$rs2,
@@ -1697,21 +1704,21 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
16971704
foreach vti = AllIntegerVectors in {
16981705
defvar suffix = vti.LMul.MX;
16991706
let Predicates = GetVTypePredicates<vti>.Predicates in {
1700-
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
1707+
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
17011708
(vti.Vector (op vti.RegClass:$rd,
17021709
(riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rs2,
17031710
srcvalue, (vti.Mask true_mask), VLOpFrag),
17041711
srcvalue, (vti.Mask true_mask), VLOpFrag)),
1705-
vti.RegClass:$rd, VLOpFrag),
1712+
vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
17061713
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
17071714
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
17081715
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
1709-
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
1716+
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
17101717
(vti.Vector (op vti.RegClass:$rd,
17111718
(riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rs2,
17121719
srcvalue, (vti.Mask true_mask), VLOpFrag),
17131720
srcvalue, (vti.Mask true_mask), VLOpFrag)),
1714-
vti.RegClass:$rd, VLOpFrag),
1721+
vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
17151722
(!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK")
17161723
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
17171724
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1840,17 +1847,17 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
18401847
foreach vti = AllFloatVectors in {
18411848
defvar suffix = vti.LMul.MX;
18421849
let Predicates = GetVTypePredicates<vti>.Predicates in {
1843-
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
1850+
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
18441851
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
18451852
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
1846-
vti.RegClass:$rd, VLOpFrag),
1853+
vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
18471854
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
18481855
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
18491856
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
1850-
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
1857+
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
18511858
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
18521859
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
1853-
vti.RegClass:$rd, VLOpFrag),
1860+
vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
18541861
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
18551862
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
18561863
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1876,21 +1883,21 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
18761883
foreach vti = AllFloatVectors in {
18771884
defvar suffix = vti.LMul.MX;
18781885
let Predicates = GetVTypePredicates<vti>.Predicates in {
1879-
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
1886+
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
18801887
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
18811888
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
1882-
vti.RegClass:$rd, VLOpFrag),
1889+
vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
18831890
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
18841891
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
18851892
(vti.Mask V0),
18861893
// Value to indicate no rounding mode change in
18871894
// RISCVInsertReadWriteCSR
18881895
FRM_DYN,
18891896
GPR:$vl, vti.Log2SEW, TU_MU)>;
1890-
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
1897+
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
18911898
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
18921899
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
1893-
vti.RegClass:$rd, VLOpFrag),
1900+
vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
18941901
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
18951902
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
18961903
(vti.Mask V0),
@@ -2273,29 +2280,32 @@ foreach vti = AllIntegerVectors in {
22732280
(vti.Vector (IMPLICIT_DEF)),
22742281
vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
22752282

2276-
def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
2277-
vti.RegClass:$rs1,
2278-
vti.RegClass:$rs2,
2279-
VLOpFrag)),
2283+
def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
2284+
vti.RegClass:$rs1,
2285+
vti.RegClass:$rs2,
2286+
vti.RegClass:$merge,
2287+
VLOpFrag)),
22802288
(!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX)
2281-
vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1,
2282-
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
2289+
vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
2290+
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
22832291

2284-
def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
2285-
(SplatPat XLenVT:$rs1),
2286-
vti.RegClass:$rs2,
2287-
VLOpFrag)),
2292+
def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
2293+
(SplatPat XLenVT:$rs1),
2294+
vti.RegClass:$rs2,
2295+
vti.RegClass:$merge,
2296+
VLOpFrag)),
22882297
(!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX)
2289-
vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1,
2290-
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
2291-
2292-
def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
2293-
(SplatPat_simm5 simm5:$rs1),
2294-
vti.RegClass:$rs2,
2295-
VLOpFrag)),
2298+
vti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
2299+
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
2300+
2301+
def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
2302+
(SplatPat_simm5 simm5:$rs1),
2303+
vti.RegClass:$rs2,
2304+
vti.RegClass:$merge,
2305+
VLOpFrag)),
22962306
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
2297-
vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1,
2298-
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
2307+
vti.RegClass:$merge, vti.RegClass:$rs2, simm5:$rs1,
2308+
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
22992309
}
23002310
}
23012311

@@ -2493,21 +2503,23 @@ foreach fvti = AllFloatVectors in {
24932503
(fvti.Vector (IMPLICIT_DEF)),
24942504
fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
24952505

2496-
def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
2497-
fvti.RegClass:$rs1,
2498-
fvti.RegClass:$rs2,
2499-
VLOpFrag)),
2500-
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
2501-
fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
2502-
GPR:$vl, fvti.Log2SEW)>;
2503-
2504-
def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
2505-
(SplatFPOp (fvti.Scalar fpimm0)),
2506-
fvti.RegClass:$rs2,
2507-
VLOpFrag)),
2508-
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
2509-
fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
2510-
GPR:$vl, fvti.Log2SEW)>;
2506+
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
2507+
fvti.RegClass:$rs1,
2508+
fvti.RegClass:$rs2,
2509+
fvti.RegClass:$merge,
2510+
VLOpFrag)),
2511+
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
2512+
fvti.RegClass:$merge, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
2513+
GPR:$vl, fvti.Log2SEW)>;
2514+
2515+
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
2516+
(SplatFPOp (fvti.Scalar fpimm0)),
2517+
fvti.RegClass:$rs2,
2518+
fvti.RegClass:$merge,
2519+
VLOpFrag)),
2520+
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
2521+
fvti.RegClass:$merge, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
2522+
GPR:$vl, fvti.Log2SEW)>;
25112523
}
25122524

25132525
let Predicates = GetVTypePredicates<fvti>.Predicates in {
@@ -2521,12 +2533,13 @@ foreach fvti = AllFloatVectors in {
25212533
(fvti.Scalar fvti.ScalarRegClass:$rs1),
25222534
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
25232535

2524-
def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
2525-
(SplatFPOp fvti.ScalarRegClass:$rs1),
2526-
fvti.RegClass:$rs2,
2527-
VLOpFrag)),
2536+
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
2537+
(SplatFPOp fvti.ScalarRegClass:$rs1),
2538+
fvti.RegClass:$rs2,
2539+
fvti.RegClass:$merge,
2540+
VLOpFrag)),
25282541
(!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX)
2529-
fvti.RegClass:$rs2, fvti.RegClass:$rs2,
2542+
fvti.RegClass:$merge, fvti.RegClass:$rs2,
25302543
(fvti.Scalar fvti.ScalarRegClass:$rs1),
25312544
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
25322545

0 commit comments

Comments
 (0)