[RISCV] Replace RISCVISD::VP_MERGE_VL with a new node that has a separate passthru operand. #75682
ISD::VP_MERGE treats the false operand as the source for elements past VL. The vmerge instruction encodes 3 registers and treats the vd register as the source for the tail. This patch adds a new ISD opcode that models the tail source explicitly; during lowering we copy the false operand to this operand. I think we can merge RISCVISD::VSELECT_VL with this new opcode by using an UNDEF passthru, but I'll save that for another patch.
@llvm/pr-subscribers-backend-risc-v
Author: Craig Topper (topperc)
Full diff: https://github.com/llvm/llvm-project/pull/75682.diff — 3 Files Affected:
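For orientation, here is a minimal sketch of what the lowering does. This is not code from the patch and the helper name is hypothetical; the operand order and the getNode call mirror the lowerVectorIntrinsicScalars change in the diff below. The key point is that the false operand is passed a second time as the passthru, so elements past VL still come from the false operand, preserving ISD::VP_MERGE semantics.

#include "llvm/CodeGen/SelectionDAG.h" // SelectionDAG, SDValue, SDLoc, EVT
#include "RISCVISelLowering.h"         // RISCVISD::VMERGE_VL
using namespace llvm;

// Sketch only, not code from the patch. RISCVISD::VMERGE_VL operand order:
// mask, true, false, passthru, vl.
static SDValue buildVPMergeSketch(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                                  SDValue Mask, SDValue True, SDValue False,
                                  SDValue VL) {
  // Reusing False as the passthru keeps ISD::VP_MERGE semantics: elements
  // past VL are taken from the false operand rather than left undefined.
  return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, True, False,
                     /*Passthru=*/False, VL);
}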
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4a8ff73ec47295..7a9cab0aeb0db4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5519,7 +5519,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
case ISD::VP_SELECT:
return RISCVISD::VSELECT_VL;
case ISD::VP_MERGE:
- return RISCVISD::VP_MERGE_VL;
+ return RISCVISD::VMERGE_VL;
case ISD::VP_ASHR:
return RISCVISD::SRA_VL;
case ISD::VP_LSHR:
@@ -5567,6 +5567,8 @@ static bool hasMergeOp(unsigned Opcode) {
return true;
if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
return true;
+ if (Opcode == RISCVISD::VMERGE_VL)
+ return true;
return false;
}
@@ -8229,8 +8231,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
AVL);
// TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
// It's fine because vmerge does not care mask policy.
- return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff,
- AVL);
+ return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec,
+ MaskedOff, MaskedOff, AVL);
}
}
@@ -10303,9 +10305,19 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
for (const auto &OpIdx : enumerate(Op->ops())) {
SDValue V = OpIdx.value();
assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
- // Add dummy merge value before the mask.
- if (HasMergeOp && *ISD::getVPMaskIdx(Op.getOpcode()) == OpIdx.index())
- Ops.push_back(DAG.getUNDEF(ContainerVT));
+ // Add dummy merge value before the mask. Or if there isn't a mask, before
+ // EVL.
+ if (HasMergeOp) {
+ auto MaskIdx = ISD::getVPMaskIdx(Op.getOpcode());
+ if (MaskIdx) {
+ if (*MaskIdx == OpIdx.index())
+ Ops.push_back(DAG.getUNDEF(ContainerVT));
+ } else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) == OpIdx.index()) {
+ // For VP_MERGE, copy the false operand instead of an undef value.
+ assert(Op.getOpcode() == ISD::VP_MERGE);
+ Ops.push_back(Ops.back());
+ }
+ }
// Pass through operands which aren't fixed-length vectors.
if (!V.getValueType().isFixedLengthVector()) {
Ops.push_back(V);
@@ -18561,7 +18573,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
- NODE_NAME_CASE(VP_MERGE_VL)
+ NODE_NAME_CASE(VMERGE_VL)
NODE_NAME_CASE(VMAND_VL)
NODE_NAME_CASE(VMOR_VL)
NODE_NAME_CASE(VMXOR_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 41a2dc5771c82d..765c6d3fb3b7c6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -332,10 +332,8 @@ enum NodeType : unsigned {
// Vector select with an additional VL operand. This operation is unmasked.
VSELECT_VL,
- // Vector select with operand #2 (the value when the condition is false) tied
- // to the destination and an additional VL operand. This operation is
- // unmasked.
- VP_MERGE_VL,
+ // General vmerge node with mask, true, false, passthru, and vl operands.
+ VMERGE_VL,
// Mask binary operators.
VMAND_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index dc6b57fad32105..33bdc3366aa3e3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -344,7 +344,14 @@ def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
]>;
def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
-def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;
+
+def SDT_RISCVVMERGE_VL : SDTypeProfile<1, 5, [
+ SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
+ SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameAs<0, 4>,
+ SDTCisVT<5, XLenVT>
+]>;
+
+def riscv_vmerge_vl : SDNode<"RISCVISD::VMERGE_VL", SDT_RISCVVMERGE_VL>;
def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
SDTCisVT<1, XLenVT>]>;
@@ -675,14 +682,14 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
op2_reg_class:$rs2,
GPR:$vl, sew, TAIL_AGNOSTIC)>;
// Tail undisturbed
- def : Pat<(riscv_vp_merge_vl true_mask,
+ def : Pat<(riscv_vmerge_vl true_mask,
(result_type (vop
result_reg_class:$rs1,
(op2_type op2_reg_class:$rs2),
srcvalue,
true_mask,
VLOpFrag)),
- result_reg_class:$rs1, VLOpFrag),
+ result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
result_reg_class:$rs1,
op2_reg_class:$rs2,
@@ -712,14 +719,14 @@ multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
FRM_DYN,
GPR:$vl, sew, TAIL_AGNOSTIC)>;
// Tail undisturbed
- def : Pat<(riscv_vp_merge_vl true_mask,
+ def : Pat<(riscv_vmerge_vl true_mask,
(result_type (vop
result_reg_class:$rs1,
(op2_type op2_reg_class:$rs2),
srcvalue,
true_mask,
VLOpFrag)),
- result_reg_class:$rs1, VLOpFrag),
+ result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
result_reg_class:$rs1,
op2_reg_class:$rs2,
@@ -1697,21 +1704,21 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
foreach vti = AllIntegerVectors in {
defvar suffix = vti.LMul.MX;
let Predicates = GetVTypePredicates<vti>.Predicates in {
- def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+ def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (op vti.RegClass:$rd,
(riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rs2,
srcvalue, (vti.Mask true_mask), VLOpFrag),
srcvalue, (vti.Mask true_mask), VLOpFrag)),
- vti.RegClass:$rd, VLOpFrag),
+ vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
- def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+ def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (op vti.RegClass:$rd,
(riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rs2,
srcvalue, (vti.Mask true_mask), VLOpFrag),
srcvalue, (vti.Mask true_mask), VLOpFrag)),
- vti.RegClass:$rd, VLOpFrag),
+ vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1840,17 +1847,17 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
foreach vti = AllFloatVectors in {
defvar suffix = vti.LMul.MX;
let Predicates = GetVTypePredicates<vti>.Predicates in {
- def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+ def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
- vti.RegClass:$rd, VLOpFrag),
+ vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
- def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+ def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
- vti.RegClass:$rd, VLOpFrag),
+ vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1876,10 +1883,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
foreach vti = AllFloatVectors in {
defvar suffix = vti.LMul.MX;
let Predicates = GetVTypePredicates<vti>.Predicates in {
- def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+ def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
- vti.RegClass:$rd, VLOpFrag),
+ vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0),
@@ -1887,10 +1894,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
// RISCVInsertReadWriteCSR
FRM_DYN,
GPR:$vl, vti.Log2SEW, TU_MU)>;
- def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+ def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
- vti.RegClass:$rd, VLOpFrag),
+ vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0),
@@ -2273,29 +2280,32 @@ foreach vti = AllIntegerVectors in {
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
- vti.RegClass:$rs1,
- vti.RegClass:$rs2,
- VLOpFrag)),
+ def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+ vti.RegClass:$rs1,
+ vti.RegClass:$rs2,
+ vti.RegClass:$merge,
+ VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX)
- vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1,
- (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+ vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
- (SplatPat XLenVT:$rs1),
- vti.RegClass:$rs2,
- VLOpFrag)),
+ def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+ (SplatPat XLenVT:$rs1),
+ vti.RegClass:$rs2,
+ vti.RegClass:$merge,
+ VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX)
- vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1,
- (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
-
- def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
- (SplatPat_simm5 simm5:$rs1),
- vti.RegClass:$rs2,
- VLOpFrag)),
+ vti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+
+ def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+ (SplatPat_simm5 simm5:$rs1),
+ vti.RegClass:$rs2,
+ vti.RegClass:$merge,
+ VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
- vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1,
- (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+ vti.RegClass:$merge, vti.RegClass:$rs2, simm5:$rs1,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
}
}
@@ -2493,21 +2503,23 @@ foreach fvti = AllFloatVectors in {
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
- def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
- fvti.RegClass:$rs1,
- fvti.RegClass:$rs2,
- VLOpFrag)),
- (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
- fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
- GPR:$vl, fvti.Log2SEW)>;
-
- def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
- (SplatFPOp (fvti.Scalar fpimm0)),
- fvti.RegClass:$rs2,
- VLOpFrag)),
- (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
- fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
- GPR:$vl, fvti.Log2SEW)>;
+ def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+ fvti.RegClass:$rs1,
+ fvti.RegClass:$rs2,
+ fvti.RegClass:$merge,
+ VLOpFrag)),
+ (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
+ fvti.RegClass:$merge, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
+ GPR:$vl, fvti.Log2SEW)>;
+
+ def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+ (SplatFPOp (fvti.Scalar fpimm0)),
+ fvti.RegClass:$rs2,
+ fvti.RegClass:$merge,
+ VLOpFrag)),
+ (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
+ fvti.RegClass:$merge, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
+ GPR:$vl, fvti.Log2SEW)>;
}
let Predicates = GetVTypePredicates<fvti>.Predicates in {
@@ -2521,12 +2533,13 @@ foreach fvti = AllFloatVectors in {
(fvti.Scalar fvti.ScalarRegClass:$rs1),
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
- def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
- (SplatFPOp fvti.ScalarRegClass:$rs1),
- fvti.RegClass:$rs2,
- VLOpFrag)),
+ def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+ (SplatFPOp fvti.ScalarRegClass:$rs1),
+ fvti.RegClass:$rs2,
+ fvti.RegClass:$merge,
+ VLOpFrag)),
(!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX)
- fvti.RegClass:$rs2, fvti.RegClass:$rs2,
+ fvti.RegClass:$merge, fvti.RegClass:$rs2,
(fvti.Scalar fvti.ScalarRegClass:$rs1),
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
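And a rough sketch of the follow-up idea from the description (explicitly not part of this patch, helper name hypothetical): RISCVISD::VSELECT_VL could fold into the new opcode by supplying an UNDEF passthru, leaving the tail undefined the way an unmasked select does.

// Hypothetical follow-up, not in this patch: VSELECT_VL expressed as
// VMERGE_VL with an UNDEF passthru (tail elements left undefined).
static SDValue buildVSelectAsVMergeSketch(SelectionDAG &DAG, const SDLoc &DL,
                                          EVT VT, SDValue Mask, SDValue True,
                                          SDValue False, SDValue VL) {
  return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, True, False,
                     DAG.getUNDEF(VT), VL);
}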
LGTM