Skip to content

[RISCV] Move vnclipu patterns into DAGCombiner. #93596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented May 28, 2024

I plan to add support for multiple layers of vnclipu. For example,
i32->i8 using 2 vnclipu instructions. First clipping to 65535, then
clipping to 255. Similar for signed vnclip.

This scales poorly if we need to add patterns with 2 or 3 truncates.
Instead, move the code to DAGCombiner with new ISD opcodes to represent
VCLIP(U).

This patch just moves the existing patterns into DAG combine. Support
for multiple truncates will as a follow up. A similar patch series will be
made for the signed vnclip.

@llvmbot
Copy link
Member

llvmbot commented May 28, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

I plan to add support for multiple layers of vnclipu. For example,
i32->i8 using 2 vnclipu instructions. First clipping to 65535, then
clipping to 255. Similar for signed vnclip.

This scales poorly if we need to add patterns with 2 or 3 truncates.
Instead, move the code to DAGCombiner with new ISD opcodes to represent
VCLIP(U).

This patch just moves the existing patterns into DAG combine. Support
for multiple truncates will as a follow up. A similar patch series will be
made for the signed vnclip.


Full diff: https://github.com/llvm/llvm-project/pull/93596.diff

5 Files Affected:

  • (modified) llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h (+9)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+119-52)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+4)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td (-7)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+94-10)
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 08f056f78979a..550904516ac8e 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -373,6 +373,15 @@ inline static bool isValidRoundingMode(unsigned Mode) {
 }
 } // namespace RISCVFPRndMode
 
+namespace RISCVVXRndMode {
+enum RoundingMode {
+  RNU = 0,
+  RNE = 1,
+  RDN = 2,
+  ROD = 3,
+};
+} // namespace RISCVVXRndMode
+
 //===----------------------------------------------------------------------===//
 // Floating-point Immediates
 //
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..1862dec7b1afd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5960,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    128 &&
+                    130 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -5986,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    128 &&
+                    130 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -16087,6 +16087,117 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+  // This would be benefit for the cases where X and Y are both the same value
+  // type of low precision vectors. Since the truncate would be lowered into
+  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+  // restriction, such pattern would be expanded into a series of "vsetvli"
+  // and "vnsrl" instructions later to reach this point.
+  auto IsTruncNode = [](SDValue V) {
+    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL = V.getOperand(2);
+    auto *C = dyn_cast<ConstantSDNode>(VL);
+    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                           (isa<RegisterSDNode>(VL) &&
+                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  };
+
+  SDValue Op = N->getOperand(0);
+
+  // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+  // to distinguish such pattern.
+  while (IsTruncNode(Op)) {
+    if (!Op.hasOneUse())
+      return SDValue();
+    Op = Op.getOperand(0);
+  }
+
+  if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+    return SDValue();
+
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
+  if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+      N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N10 = N1.getOperand(0);
+  if (!N00.getValueType().isVector() ||
+      N00.getValueType() != N10.getValueType() ||
+      N->getValueType(0) != N10.getValueType())
+    return SDValue();
+
+  unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+  SDValue SMin =
+      DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+                  DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+  return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
+
+// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
+// value for the truncated type.
+static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
+                                     const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+
+  MVT VT = N->getSimpleValueType(0);
+
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  SDValue Src = N->getOperand(0);
+
+  // Src must be a UMIN or UMIN_VL.
+  if (Src.getOpcode() != ISD::UMIN &&
+      !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
+        Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
+    return SDValue();
+
+  auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
+    // Peek through conversion between fixed and scalable vectors.
+    if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
+        isNullConstant(Op.getOperand(2)) &&
+        Op.getOperand(1).getValueType().isFixedLengthVector() &&
+        Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        Op.getOperand(1).getOperand(0).getValueType() == Op.getValueType() &&
+        isNullConstant(Op.getOperand(1).getOperand(1)))
+      Op = Op.getOperand(1).getOperand(0);
+
+    if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
+      return true;
+
+    if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
+        Op.getOperand(2) == VL) {
+      if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+        SplatVal =
+            Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
+        return true;
+      }
+    }
+
+    return false;
+  };
+
+  APInt C;
+  if (!IsSplat(Src.getOperand(1), C))
+    return SDValue();
+
+  if (!C.isMask(VT.getScalarSizeInBits()))
+    return SDValue();
+
+  SDLoc DL(N);
+  // Rounding mode here is arbitrary since we aren't shifting out any bits.
+  return DAG.getNode(
+      RISCVISD::VNCLIPU_VL, DL, VT,
+      {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+       DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
+       VL});
+}
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
@@ -16304,56 +16415,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
-  case RISCVISD::TRUNCATE_VECTOR_VL: {
-    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-    // This would be benefit for the cases where X and Y are both the same value
-    // type of low precision vectors. Since the truncate would be lowered into
-    // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-    // restriction, such pattern would be expanded into a series of "vsetvli"
-    // and "vnsrl" instructions later to reach this point.
-    auto IsTruncNode = [](SDValue V) {
-      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-        return false;
-      SDValue VL = V.getOperand(2);
-      auto *C = dyn_cast<ConstantSDNode>(VL);
-      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                             (isa<RegisterSDNode>(VL) &&
-                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
-             IsVLMAXForVMSET;
-    };
-
-    SDValue Op = N->getOperand(0);
-
-    // We need to first find the inner level of TRUNCATE_VECTOR_VL node
-    // to distinguish such pattern.
-    while (IsTruncNode(Op)) {
-      if (!Op.hasOneUse())
-        return SDValue();
-      Op = Op.getOperand(0);
-    }
-
-    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
-      SDValue N0 = Op.getOperand(0);
-      SDValue N1 = Op.getOperand(1);
-      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
-          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
-        SDValue N00 = N0.getOperand(0);
-        SDValue N10 = N1.getOperand(0);
-        if (N00.getValueType().isVector() &&
-            N00.getValueType() == N10.getValueType() &&
-            N->getValueType(0) == N10.getValueType()) {
-          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
-          SDValue SMin = DAG.getNode(
-              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
-              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
-          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
-        }
-      }
-    }
-    break;
-  }
+  case RISCVISD::TRUNCATE_VECTOR_VL:
+    if (SDValue V = combineTruncOfSraSext(N, DAG))
+      return V;
+    return combineTruncToVnclipu(N, DAG, Subtarget);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:
@@ -19972,6 +20037,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(UADDSAT_VL)
   NODE_NAME_CASE(SSUBSAT_VL)
   NODE_NAME_CASE(USUBSAT_VL)
+  NODE_NAME_CASE(VNCLIP_VL)
+  NODE_NAME_CASE(VNCLIPU_VL)
   NODE_NAME_CASE(FADD_VL)
   NODE_NAME_CASE(FSUB_VL)
   NODE_NAME_CASE(FMUL_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 856ce06ba1c4f..3b8eb3c88901a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -273,6 +273,10 @@ enum NodeType : unsigned {
   // Rounding averaging adds of unsigned integers.
   AVGCEILU_VL,
 
+  // Operands are (source, shift, merge, mask, roundmode, vl)
+  VNCLIPU_VL,
+  VNCLIP_VL,
+
   MULHS_VL,
   MULHU_VL,
   FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 66df24f2a458d..691f2052ab29d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1196,13 +1196,6 @@ multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
       (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
         (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
         (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 91f3abe22331e..610a72dd02b38 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -36,6 +36,18 @@ def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                                 SDTCisSameNumEltsAs<0, 4>,
                                                 SDTCisVT<5, XLenVT>]>;
 
+// Input: (vector, vector/scalar, merge, mask, roundmode, vl)
+def SDT_RISCVVNBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisInt<0>,
+                                                  SDTCisSameAs<0, 3>,
+                                                  SDTCisSameNumEltsAs<0, 1>,
+                                                  SDTCisVec<1>,
+                                                  SDTCisOpSmallerThanOp<2, 1>,
+                                                  SDTCisSameAs<0, 2>,
+                                                  SDTCisSameNumEltsAs<0, 4>,
+                                                  SDTCVecEltisVT<4, i1>,
+                                                  SDTCisVT<5, XLenVT>,
+                                                  SDTCisVT<6, XLenVT>]>;
+
 def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
                                               SDTCisVec<0>, SDTCisFP<0>,
                                               SDTCVecEltisVT<2, i1>,
@@ -120,6 +132,9 @@ def riscv_uaddsat_vl   : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
 def riscv_ssubsat_vl   : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 def riscv_usubsat_vl   : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 
+def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
+def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
+
 def riscv_fadd_vl  : SDNode<"RISCVISD::FADD_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
 def riscv_fsub_vl  : SDNode<"RISCVISD::FSUB_VL",  SDT_RISCVFPBinOp_VL>;
 def riscv_fmul_vl  : SDNode<"RISCVISD::FMUL_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -635,6 +650,34 @@ class VPatBinaryVL_V<SDPatternOperator vop,
                    op2_reg_class:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
+multiclass VPatBinaryRM_VL_V<SDNode vop,
+                             string instruction_name,
+                             string suffix,
+                             ValueType result_type,
+                             ValueType op1_type,
+                             ValueType op2_type,
+                             ValueType mask_type,
+                             int sew,
+                             LMULInfo vlmul,
+                             VReg result_reg_class,
+                             VReg op1_reg_class,
+                             VReg op2_reg_class> {
+  def : Pat<(result_type (vop
+                         (op1_type op1_reg_class:$rs1),
+                         (op2_type op2_reg_class:$rs2),
+                         (result_type result_reg_class:$merge),
+                         (mask_type V0),
+                         (XLenVT timm:$roundmode),
+                         VLOpFrag)),
+        (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+                     result_reg_class:$merge,
+                     op1_reg_class:$rs1,
+                     op2_reg_class:$rs2,
+                     (mask_type V0),
+                     (XLenVT timm:$roundmode),
+                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
 class VPatBinaryVL_V_RM<SDPatternOperator vop,
                         string instruction_name,
                         string suffix,
@@ -795,6 +838,35 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
                    xop_kind:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
+multiclass VPatBinaryRM_VL_XI<SDNode vop,
+                              string instruction_name,
+                              string suffix,
+                              ValueType result_type,
+                              ValueType vop1_type,
+                              ValueType vop2_type,
+                              ValueType mask_type,
+                              int sew,
+                              LMULInfo vlmul,
+                              VReg result_reg_class,
+                              VReg vop_reg_class,
+                              ComplexPattern SplatPatKind,
+                              DAGOperand xop_kind> {
+  def : Pat<(result_type (vop
+                     (vop1_type vop_reg_class:$rs1),
+                     (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
+                     (result_type result_reg_class:$merge),
+                     (mask_type V0),
+                     (XLenVT timm:$roundmode),
+                     VLOpFrag)),
+        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
+                     result_reg_class:$merge,
+                     vop_reg_class:$rs1,
+                     xop_kind:$rs2,
+                     (mask_type V0),
+                     (XLenVT timm:$roundmode),
+                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
 multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
                               list<VTypeInfo> vtilist = AllIntegerVectors,
                               bit isSEWAware = 0> {
@@ -893,6 +965,24 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
   }
 }
 
+multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
+  foreach VtiToWti = AllWidenableIntVectors in {
+    defvar vti = VtiToWti.Vti;
+    defvar wti = VtiToWti.Wti;
+    defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
+                             vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                             vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
+    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
+                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
+    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
+                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
+                              !cast<ComplexPattern>(SplatPat#_#uimm5),
+                              uimm5>;
+  }
+}
+
 class VPatBinaryVL_VF<SDPatternOperator vop,
                       string instruction_name,
                       ValueType result_type,
@@ -2376,6 +2466,10 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
 defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
 defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
 
+// 12.5. Vector Narrowing Fixed-Point Clip Instructions
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+
 // 12.5. Vector Narrowing Fixed-Point Clip Instructions
 multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
   defvar sew = vti.SEW;
@@ -2410,16 +2504,6 @@ multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
       (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
         (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
         (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (riscv_umin_vl
-          (wti.Vector wti.RegClass:$rs1),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
-          (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
-        (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
   }
 }
 

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/minor comments.


auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
// Peek through conversion between fixed and scalable vectors.
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this just a nop? Why aren't we combining that out?

Copy link
Collaborator Author

@topperc topperc May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do eventually, but we only revisit the UMIN after that since its the immediate user. We don't revisit the truncate after that.

I think this is partially caused by the truncate and umin being legalized by LegalizeVectorOps and build_vector being legalized by LegalizeDAG.

Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
return SDValue();

auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a general useful function, surely we have parts of this repeated elsewhere already?

if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
return true;

if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add vmv_s_x when VL=1?

Copy link
Member

@sun-jacobi sun-jacobi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

I plan to add support for multiple layers of vnclipu. For example,
i32->i8 using 2 vnclipu instructions. First clipping to 65535, then
clipping to 255. Similar for signed vnclip.

This scales poorly if we need to add patterns with 2 or 3 truncates.
Instead, move the code to DAGCombiner with new ISD opcodes to represent
VCLIP(U).

This patch just moves the existing patterns into DAG combine. Support
for multiple truncates will as a follow up. A similar patch will be
made for the signed vnclip patterns.
@topperc topperc force-pushed the pr/vnclipu-pattern branch from 9d7f119 to 20f81eb Compare May 29, 2024 19:08
@topperc topperc merged commit ec8fe59 into llvm:main May 29, 2024
4 of 6 checks passed
@topperc topperc deleted the pr/vnclipu-pattern branch May 29, 2024 20:00
topperc added a commit to topperc/llvm-project that referenced this pull request May 29, 2024
Similar to llvm#93596, this moves the signed vnclip patterns into DAG combine.

This will allows us to support more than 1 level of truncate in a
future patch.
topperc added a commit to topperc/llvm-project that referenced this pull request May 29, 2024
Similar to llvm#93596, this moves the signed vnclip patterns into DAG combine.

This will allows us to support more than 1 level of truncate in a
future patch.
topperc added a commit that referenced this pull request May 29, 2024
Similar to #93596, this moves the signed vnclip patterns into DAG
combine.
    
This will allows us to support more than 1 level of truncate in a
future patch.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants