Skip to content

[RISCV] Replace RISCVISD::VP_MERGE_VL with a new node that has a separate passthru operand. #75682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 21, 2023

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Dec 16, 2023

ISD::VP_MERGE treats the false operand as the source for elements past VL. The vmerge instruction encodes 3 registers and treats the vd register as the source for the tail.

This patch adds a new ISD opcode that models the tail source explicitly. During lowering we copy the false operand to this operand.

I think we can merge RISCVISD::VSELECT_VL with this new opcode by using an UNDEF passthru, but I'll save that for another patch.

…rate passthru operand.

ISD::VP_MERGE treats the false operand as the source for elements
past VL. The vmerge instruction encodes 3 registers and treats the
vd register as the source for the tail.

This patch adds a new ISD opcode that models the tail source
explicitly. During lowering we copy the false operand to this
operand.

I think we can merge RISCVISD::VSELECT_VL with this new opcode by
using an UNDEF passthru, but I'll save that for another patch.
@llvmbot
Copy link
Member

llvmbot commented Dec 16, 2023

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

ISD::VP_MERGE treats the false operand as the source for elements past VL. The vmerge instruction encodes 3 registers and treats the vd register as the source for the tail.

This patch adds a new ISD opcode that models the tail source explicitly. During lowering we copy the false operand to this operand.

I think we can merge RISCVISD::VSELECT_VL with this new opcode by using an UNDEF passthru, but I'll save that for another patch.


Full diff: https://github.com/llvm/llvm-project/pull/75682.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+19-7)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2-4)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+69-56)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4a8ff73ec47295..7a9cab0aeb0db4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5519,7 +5519,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
   case ISD::VP_SELECT:
     return RISCVISD::VSELECT_VL;
   case ISD::VP_MERGE:
-    return RISCVISD::VP_MERGE_VL;
+    return RISCVISD::VMERGE_VL;
   case ISD::VP_ASHR:
     return RISCVISD::SRA_VL;
   case ISD::VP_LSHR:
@@ -5567,6 +5567,8 @@ static bool hasMergeOp(unsigned Opcode) {
     return true;
   if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
     return true;
+  if (Opcode == RISCVISD::VMERGE_VL)
+    return true;
   return false;
 }
 
@@ -8229,8 +8231,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
                          AVL);
     // TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
     // It's fine because vmerge does not care mask policy.
-    return DAG.getNode(RISCVISD::VP_MERGE_VL, DL, VT, Mask, Vec, MaskedOff,
-                       AVL);
+    return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec,
+                       MaskedOff, MaskedOff, AVL);
   }
   }
 
@@ -10303,9 +10305,19 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
   for (const auto &OpIdx : enumerate(Op->ops())) {
     SDValue V = OpIdx.value();
     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
-    // Add dummy merge value before the mask.
-    if (HasMergeOp && *ISD::getVPMaskIdx(Op.getOpcode()) == OpIdx.index())
-      Ops.push_back(DAG.getUNDEF(ContainerVT));
+    // Add dummy merge value before the mask. Or if there isn't a mask, before
+    // EVL.
+    if (HasMergeOp) {
+      auto MaskIdx = ISD::getVPMaskIdx(Op.getOpcode());
+      if (MaskIdx) {
+        if (*MaskIdx == OpIdx.index())
+          Ops.push_back(DAG.getUNDEF(ContainerVT));
+      } else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) == OpIdx.index()) {
+        // For VP_MERGE, copy the false operand instead of an undef value.
+        assert(Op.getOpcode() == ISD::VP_MERGE);
+        Ops.push_back(Ops.back());
+      }
+    }
     // Pass through operands which aren't fixed-length vectors.
     if (!V.getValueType().isFixedLengthVector()) {
       Ops.push_back(V);
@@ -18561,7 +18573,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)
-  NODE_NAME_CASE(VP_MERGE_VL)
+  NODE_NAME_CASE(VMERGE_VL)
   NODE_NAME_CASE(VMAND_VL)
   NODE_NAME_CASE(VMOR_VL)
   NODE_NAME_CASE(VMXOR_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 41a2dc5771c82d..765c6d3fb3b7c6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -332,10 +332,8 @@ enum NodeType : unsigned {
 
   // Vector select with an additional VL operand. This operation is unmasked.
   VSELECT_VL,
-  // Vector select with operand #2 (the value when the condition is false) tied
-  // to the destination and an additional VL operand. This operation is
-  // unmasked.
-  VP_MERGE_VL,
+  // General vmerge node with mask, true, false, passthru, and vl operands.
+  VMERGE_VL,
 
   // Mask binary operators.
   VMAND_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index dc6b57fad32105..33bdc3366aa3e3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -344,7 +344,14 @@ def SDT_RISCVSelect_VL  : SDTypeProfile<1, 4, [
 ]>;
 
 def riscv_vselect_vl  : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
-def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;
+
+def SDT_RISCVVMERGE_VL  : SDTypeProfile<1, 5, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
+  SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameAs<0, 4>,
+  SDTCisVT<5, XLenVT>
+]>;
+
+def riscv_vmerge_vl : SDNode<"RISCVISD::VMERGE_VL", SDT_RISCVVMERGE_VL>;
 
 def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
                                                 SDTCisVT<1, XLenVT>]>;
@@ -675,14 +682,14 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
                      op2_reg_class:$rs2,
                      GPR:$vl, sew, TAIL_AGNOSTIC)>;
   // Tail undisturbed
-  def : Pat<(riscv_vp_merge_vl true_mask,
+  def : Pat<(riscv_vmerge_vl true_mask,
              (result_type (vop
                            result_reg_class:$rs1,
                            (op2_type op2_reg_class:$rs2),
                            srcvalue,
                            true_mask,
                            VLOpFrag)),
-             result_reg_class:$rs1, VLOpFrag),
+             result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
             (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
                      result_reg_class:$rs1,
                      op2_reg_class:$rs2,
@@ -712,14 +719,14 @@ multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
                      FRM_DYN,
                      GPR:$vl, sew, TAIL_AGNOSTIC)>;
   // Tail undisturbed
-  def : Pat<(riscv_vp_merge_vl true_mask,
+  def : Pat<(riscv_vmerge_vl true_mask,
              (result_type (vop
                            result_reg_class:$rs1,
                            (op2_type op2_reg_class:$rs2),
                            srcvalue,
                            true_mask,
                            VLOpFrag)),
-             result_reg_class:$rs1, VLOpFrag),
+             result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
             (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_TIED")
                      result_reg_class:$rs1,
                      op2_reg_class:$rs2,
@@ -1697,21 +1704,21 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
   foreach vti = AllIntegerVectors in {
   defvar suffix = vti.LMul.MX;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                 (vti.Vector (op vti.RegClass:$rd,
                                 (riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rs2,
                                     srcvalue, (vti.Mask true_mask), VLOpFrag),
                                 srcvalue, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                 (vti.Vector (op vti.RegClass:$rd,
                                 (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rs2,
                                     srcvalue, (vti.Mask true_mask), VLOpFrag),
                                 srcvalue, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1840,17 +1847,17 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
                    vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
@@ -1876,10 +1883,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
                    vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0),
@@ -1887,10 +1894,10 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
                    // RISCVInsertReadWriteCSR
                    FRM_DYN,
                    GPR:$vl, vti.Log2SEW, TU_MU)>;
-    def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
+    def : Pat<(riscv_vmerge_vl (vti.Mask V0),
                            (vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
                             vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
-                            vti.RegClass:$rd, VLOpFrag),
+                            vti.RegClass:$rd, vti.RegClass:$rd, VLOpFrag),
               (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
                    vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask V0),
@@ -2273,29 +2280,32 @@ foreach vti = AllIntegerVectors in {
                    (vti.Vector (IMPLICIT_DEF)),
                    vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
 
-    def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
-                                             vti.RegClass:$rs1,
-                                             vti.RegClass:$rs2,
-                                             VLOpFrag)),
+    def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+                                           vti.RegClass:$rs1,
+                                           vti.RegClass:$rs2,
+                                           vti.RegClass:$merge,
+                                           VLOpFrag)),
               (!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX)
-                   vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1,
-                   (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+                  vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
+                  (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
 
-    def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
-                                             (SplatPat XLenVT:$rs1),
-                                             vti.RegClass:$rs2,
-                                             VLOpFrag)),
+    def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+                                            (SplatPat XLenVT:$rs1),
+                                            vti.RegClass:$rs2,
+                                            vti.RegClass:$merge,
+                                            VLOpFrag)),
               (!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX)
-                   vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1,
-                   (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
-
-    def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
-                                             (SplatPat_simm5 simm5:$rs1),
-                                             vti.RegClass:$rs2,
-                                             VLOpFrag)),
+                  vti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
+                  (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+
+    def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
+                                           (SplatPat_simm5 simm5:$rs1),
+                                           vti.RegClass:$rs2,
+                                           vti.RegClass:$merge,
+                                           VLOpFrag)),
               (!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
-                   vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1,
-                   (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
+                  vti.RegClass:$merge, vti.RegClass:$rs2, simm5:$rs1,
+                  (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
   }
 }
 
@@ -2493,21 +2503,23 @@ foreach fvti = AllFloatVectors in {
                    (fvti.Vector (IMPLICIT_DEF)),
                    fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
 
-    def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
-                                              fvti.RegClass:$rs1,
-                                              fvti.RegClass:$rs2,
-                                              VLOpFrag)),
-              (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
-                   GPR:$vl, fvti.Log2SEW)>;
-
-    def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
-                                              (SplatFPOp (fvti.Scalar fpimm0)),
-                                              fvti.RegClass:$rs2,
-                                              VLOpFrag)),
-              (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
-                   GPR:$vl, fvti.Log2SEW)>;
+  def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+                                          fvti.RegClass:$rs1,
+                                          fvti.RegClass:$rs2,
+                                          fvti.RegClass:$merge,
+                                          VLOpFrag)),
+            (!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
+                 fvti.RegClass:$merge, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
+                 GPR:$vl, fvti.Log2SEW)>;
+
+  def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+                                          (SplatFPOp (fvti.Scalar fpimm0)),
+                                          fvti.RegClass:$rs2,
+                                          fvti.RegClass:$merge,
+                                          VLOpFrag)),
+            (!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
+                 fvti.RegClass:$merge, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
+                 GPR:$vl, fvti.Log2SEW)>;
   }
 
   let Predicates = GetVTypePredicates<fvti>.Predicates in {
@@ -2521,12 +2533,13 @@ foreach fvti = AllFloatVectors in {
                    (fvti.Scalar fvti.ScalarRegClass:$rs1),
                    (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
 
-    def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
-                                              (SplatFPOp fvti.ScalarRegClass:$rs1),
-                                              fvti.RegClass:$rs2,
-                                              VLOpFrag)),
+    def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
+                                            (SplatFPOp fvti.ScalarRegClass:$rs1),
+                                            fvti.RegClass:$rs2,
+                                            fvti.RegClass:$merge,
+                                            VLOpFrag)),
               (!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX)
-                   fvti.RegClass:$rs2, fvti.RegClass:$rs2,
+                   fvti.RegClass:$merge, fvti.RegClass:$rs2,
                    (fvti.Scalar fvti.ScalarRegClass:$rs1),
                    (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
 

Copy link

github-actions bot commented Dec 16, 2023

:white_check_mark: With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc topperc merged commit e64f5d6 into llvm:main Dec 21, 2023
@topperc topperc deleted the pr/vmerge branch December 21, 2023 22:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants