Skip to content

[VP][RISCV] Add vp.cttz.elts intrinsic and its RISC-V codegen #90502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 30, 2024

Conversation

mshockwave
Copy link
Member

@mshockwave mshockwave commented Apr 29, 2024

This intrinsic is the VP version of experimental.cttz.elts.

@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2024

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-llvm-ir

Author: Min-Yih Hsu (mshockwave)

Changes

This intrinsic is basically the VP version of experimental.cttz.elts.


Patch is 27.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90502.diff

13 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+48)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+5)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+6)
  • (modified) llvm/include/llvm/IR/VPIntrinsics.def (+9)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+10)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+61)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+8-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+32)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+46-1)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+1)
  • (added) llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll (+204)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 37662f79145d67..f79c1fd9278de3 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -24001,6 +24001,54 @@ Examples:
       %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
 
 
+.. _int_vp_cttz_elts:
+
+'``llvm.vp.cttz.elts.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic. You can use ```llvm.vp.cttz.elts``` on any
+vector of integer elements, both fixed width and scalable.
+
+::
+
+      declare i32  @llvm.vp.cttz.elts.i32.v16i32 (<16 x i32> <op>, i1 <is_zero_poison>, <16 x i1> <mask>, i32 <vector_length>)
+      declare i64  @llvm.vp.cttz.elts.i64.nxv4i32 (<vscale x 4 x i32> <op>, i1 <is_zero_poison>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare i64  @llvm.vp.cttz.elts.i64.v256i1 (<256 x i1> <op>, i1 <is_zero_poison>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+This '```llvm.vp.cttz.elts```' intrinsic counts the number of trailing zero
+elements of a vector. This is basically the vector-predicated version of
+'```llvm.experimental.cttz.elts```'.
+
+Arguments:
+""""""""""
+
+The first argument is the vector to be counted. This argument must be a vector
+with integer element type. The return type must also be an integer type which is
+wide enough to hold the maximum number of elements of the source vector. The
+behavior of this intrinsic is undefined if the return type is not wide enough
+for the number of elements in the input vector.
+
+The second argument is a constant flag that indicates whether the intrinsic
+returns a valid result if the first argument is all zero.
+
+The third operand is the vector mask and has the same number of elements as the
+input vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.cttz.elts``' intrinsic counts the trailing (least
+significant / lowest-numbered) zero elements in the first operand on each
+enabled lane. If the first argument is all zero and the second argument is true,
+the result is poison. Otherwise, it returns the explicit vector length (i.e. the
+fourth operand).
+
 .. _int_vp_sadd_sat:
 
 '``llvm.vp.sadd.sat.*``' Intrinsics
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 661b2841c6ac72..7ed08cfa8a2022 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5307,6 +5307,11 @@ class TargetLowering : public TargetLoweringBase {
   /// \returns The expansion result or SDValue() if it fails.
   SDValue expandVPCTTZ(SDNode *N, SelectionDAG &DAG) const;
 
+  /// Expand VP_CTTZ_ELTS/VP_CTTZ_ELTS_ZERO_UNDEF nodes.
+  /// \param N Node to expand
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;
+
   /// Expand ABS nodes. Expands vector/scalar ABS nodes,
   /// vector nodes can only succeed if all operations are legal/custom.
   /// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index a2678d69ce4062..28116e5316c96b 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2255,6 +2255,12 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>
                                llvm_i1_ty,
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                                llvm_i32_ty]>;
+
+  def int_vp_cttz_elts : DefaultAttrsIntrinsic<[ llvm_anyint_ty ],
+                                  [ llvm_anyvector_ty,
+                                    llvm_i1_ty,
+                                    LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>,
+                                    llvm_i32_ty]>;
 }
 
 def int_get_active_lane_mask:
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 1c2708a9e85437..f1cc8bcae467be 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -282,6 +282,15 @@ BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF, -1, vp_cttz_zero_undef, 1, 2)
 END_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF)
 END_REGISTER_VP_INTRINSIC(vp_cttz)
 
+// llvm.vp.cttz.elts(x,is_zero_poison,mask,vl)
+BEGIN_REGISTER_VP_INTRINSIC(vp_cttz_elts, 2, 3)
+VP_PROPERTY_NO_FUNCTIONAL
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS, 0, vp_cttz_elts, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS)
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF, 0, vp_cttz_elts_zero_undef, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF)
+END_REGISTER_VP_INTRINSIC(vp_cttz_elts)
+
 // llvm.vp.fshl(x,y,z,mask,vlen)
 BEGIN_REGISTER_VP(vp_fshl, 3, 4, VP_FSHL, -1)
 VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshl)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 55f9737bc94dd5..0aa36deda79dcc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -76,6 +76,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_CTTZ:
   case ISD::CTTZ_ZERO_UNDEF:
   case ISD::CTTZ:        Res = PromoteIntRes_CTTZ(N); break;
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS:
+    Res = PromoteIntRes_VP_CttzElements(N);
+    break;
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
   case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
@@ -724,6 +728,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTTZ(SDNode *N) {
                      N->getOperand(2));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  EVT NewVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  return DAG.getNode(N->getOpcode(), DL, NewVT, N->ops());
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N) {
   SDLoc dl(N);
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4a2c7b355eb528..49be824deb5134 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -309,6 +309,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_CTLZ(SDNode *N);
   SDValue PromoteIntRes_CTPOP_PARITY(SDNode *N);
   SDValue PromoteIntRes_CTTZ(SDNode *N);
+  SDValue PromoteIntRes_VP_CttzElements(SDNode *N);
   SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
@@ -912,6 +913,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_FP_ROUND(SDNode *N);
   SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
+  SDValue SplitVecOp_VP_CttzElements(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Widening Support: LegalizeVectorTypes.cpp
@@ -1019,6 +1021,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
   SDValue WidenVecOp_VP_REDUCE(SDNode *N);
   SDValue WidenVecOp_ExpOp(SDNode *N);
+  SDValue WidenVecOp_VP_CttzElements(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..154e78075eb593 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1026,6 +1026,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
       return;
     }
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    if (SDValue Expanded = TLI.expandVPCTTZElements(Node, DAG)) {
+      Results.push_back(Expanded);
+      return;
+    }
+    break;
   case ISD::FSHL:
   case ISD::VP_FSHL:
   case ISD::FSHR:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 985c9f16ab97cd..c1865b2b7d59a8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -22,6 +22,7 @@
 #include "LegalizeTypes.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -3098,6 +3099,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMIN:
     Res = SplitVecOp_VP_REDUCE(N, OpNo);
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Res = SplitVecOp_VP_CttzElements(N);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -4056,6 +4061,29 @@ SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+
+  SDValue Lo, Hi;
+  SDValue VecOp = N->getOperand(0);
+  GetSplitVector(VecOp, Lo, Hi);
+
+  auto [MaskLo, MaskHi] = SplitMask(N->getOperand(1));
+  auto [EVLLo, EVLHi] =
+      DAG.SplitEVL(N->getOperand(2), VecOp.getValueType(), DL);
+  SDValue VLo = DAG.getZExtOrTrunc(EVLLo, DL, ResVT);
+
+  // if VP_CTTZ_ELTS(Lo) != EVLLo => VP_CTTZ_ELTS(Lo).
+  // else => EVLLo + VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF(Hi).
+  SDValue ResLo = DAG.getNode(ISD::VP_CTTZ_ELTS, DL, ResVT, Lo, MaskLo, EVLLo);
+  SDValue ResLoNotEVL =
+      DAG.getSetCC(DL, getSetCCResultType(ResVT), ResLo, VLo, ISD::SETNE);
+  SDValue ResHi = DAG.getNode(N->getOpcode(), DL, ResVT, Hi, MaskHi, EVLHi);
+  return DAG.getSelect(DL, ResVT, ResLoNotEVL, ResLo,
+                       DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Widening
 //===----------------------------------------------------------------------===//
@@ -6161,6 +6189,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMIN:
     Res = WidenVecOp_VP_REDUCE(N);
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Res = WidenVecOp_VP_CttzElements(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -6924,6 +6956,35 @@ SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
                      DAG.getVectorIdxConstant(0, DL));
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
+  // Widen the result bit width if needed.
+  SDLoc DL(N);
+  EVT OrigResVT = N->getValueType(0);
+  const Function &F = DAG.getMachineFunction().getFunction();
+
+  SDValue Source = GetWidenedVector(N->getOperand(0));
+  EVT SrcVT = Source.getValueType();
+  SDValue Mask =
+      GetWidenedMask(N->getOperand(1), SrcVT.getVectorElementCount());
+
+  // Compute the number of bits that can fit the result.
+  ConstantRange CR(APInt(64, SrcVT.getVectorMinNumElements()));
+  if (SrcVT.isScalableVT())
+    CR = CR.umul_sat(getVScaleRange(&F, 64));
+  // If the zero-is-poison flag is set in the original intrinsic, we can
+  // assume the upper limit of the result is EVL - 1.
+  if (N->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF)
+    CR = CR.subtract(APInt(64, 1));
+
+  unsigned NewResWidth = OrigResVT.getScalarSizeInBits();
+  NewResWidth = std::min(NewResWidth, unsigned(CR.getActiveBits()));
+  NewResWidth = std::max(llvm::bit_ceil(NewResWidth), 8U);
+
+  EVT NewResVT = EVT::getIntegerVT(*DAG.getContext(), NewResWidth);
+  return DAG.getNode(N->getOpcode(), DL, NewResVT, Source, Mask,
+                     N->getOperand(2));
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Widening Utilities
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5caf868c83a296..cfd82a342433fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8076,6 +8076,11 @@ static unsigned getISDForVPIntrinsic(const VPIntrinsic &VPIntrin) {
     ResOPC = IsZeroUndef ? ISD::VP_CTTZ_ZERO_UNDEF : ISD::VP_CTTZ;
     break;
   }
+  case Intrinsic::vp_cttz_elts: {
+    bool IsZeroPoison = cast<ConstantInt>(VPIntrin.getArgOperand(1))->isOne();
+    ResOPC = IsZeroPoison ? ISD::VP_CTTZ_ELTS_ZERO_UNDEF : ISD::VP_CTTZ_ELTS;
+    break;
+  }
 #define HELPER_MAP_VPID_TO_VPSD(VPID, VPSD)                                    \
   case Intrinsic::VPID:                                                        \
     ResOPC = ISD::VPSD;                                                        \
@@ -8428,7 +8433,9 @@ void SelectionDAGBuilder::visitVectorPredicationIntrinsic(
   case ISD::VP_CTLZ:
   case ISD::VP_CTLZ_ZERO_UNDEF:
   case ISD::VP_CTTZ:
-  case ISD::VP_CTTZ_ZERO_UNDEF: {
+  case ISD::VP_CTTZ_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS: {
     SDValue Result =
         DAG.getNode(Opcode, DL, VTs, {OpValues[0], OpValues[2], OpValues[3]});
     setValue(&VPIntrin, Result);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cdc1227fd572dc..a6a14a21dd80d5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9074,6 +9074,38 @@ SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
 }
 
+SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
+                                             SelectionDAG &DAG) const {
+  // %cond = to_bool_vec %source
+  // %splat = splat /*val=*/VL
+  // %tz = step_vector
+  // %v = vp.select %cond, /*true=*/tz, /*false=*/%splat
+  // %r = vp.reduce.umin %v
+  SDLoc DL(N);
+  SDValue Source = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue EVL = N->getOperand(2);
+  EVT SrcVT = Source.getValueType();
+  EVT ResVT = N->getValueType(0);
+  EVT ResVecVT =
+      EVT::getVectorVT(*DAG.getContext(), ResVT, SrcVT.getVectorElementCount());
+
+  // Convert to boolean vector.
+  if (SrcVT.getScalarType() != MVT::i1) {
+    SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+    SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                             SrcVT.getVectorElementCount());
+    Source = DAG.getSetCC(DL, SrcVT, Source, AllZero, ISD::SETNE);
+  }
+
+  SDValue ExtEVL = DAG.getZExtOrTrunc(EVL, DL, ResVT);
+  SDValue Splat = DAG.getSplat(ResVecVT, DL, ExtEVL);
+  SDValue StepVec = DAG.getStepVector(DL, ResVecVT);
+  SDValue Select =
+      DAG.getNode(ISD::VP_SELECT, DL, ResVecVT, Source, StepVec, Splat, EVL);
+  return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
+}
+
 SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
                                   bool IsNegative) const {
   SDLoc dl(N);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68f4ec5ef49f31..d6bb4542bb3263 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -698,7 +698,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
         ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
         ISD::VP_SADDSAT,     ISD::VP_UADDSAT,     ISD::VP_SSUBSAT,
-        ISD::VP_USUBSAT};
+        ISD::VP_USUBSAT,     ISD::VP_CTTZ_ELTS,   ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
 
     static const unsigned FloatingPointVPOps[] = {
         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,
@@ -759,6 +759,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
           {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
           Expand);
 
+      setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
+                         Custom);
+
       setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
 
       setOperationAction(
@@ -5341,6 +5344,45 @@ RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
   return Res;
 }
 
+SDValue RISCVTargetLowering::lowerVPCttzElements(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue Source = Op->getOperand(0);
+  MVT SrcVT = Source.getSimpleValueType();
+  SDValue Mask = Op->getOperand(1);
+  SDValue EVL = Op->getOperand(2);
+
+  if (SrcVT.isFixedLengthVector()) {
+    MVT ContainerVT = getContainerForFixedLengthVector(SrcVT);
+    Source = convertToScalableVector(ContainerVT, Source, DAG, Subtarget);
+    Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
+                                   Subtarget);
+    SrcVT = ContainerVT;
+  }
+
+  // Convert to boolean vector.
+  if (SrcVT.getScalarType() != MVT::i1) {
+    SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+    SrcVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorElementCount());
+    Source = DAG.getNode(RISCVISD::SETCC_VL, DL, SrcVT,
+                         {Source, AllZero, DAG.getCondCode(ISD::SETNE),
+                          DAG.getUNDEF(SrcVT), Mask, EVL});
+  }
+
+  SDValue Res = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Source, Mask, EVL);
+  if (Op->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF)
+    // In this case, we can interpret poison as -1, so nothing to do further.
+    return Res;
+
+  // Convert -1 to VL.
+  SDValue SetCC =
+      DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
+  Res = DAG.getSelect(DL, XLenVT, SetCC, DAG.getZExtOrTrunc(EVL, DL, XLenVT),
+                      Res);
+  return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
+}
+
 // While RVV has alignment restrictions, we should always be able to load as a
 // legal equivalently-sized byte-typed vector instead. This method is
 // responsible for re-expressing a ISD::LOAD via a correctly-aligned type. If
@@ -6595,6 +6637,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
     return lowerVPREDUCE(Op, DAG);
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    return lowerVPCttzElements(Op, DAG);
   case ISD::UNDEF: {
     MVT ContainerVT = getContainerForFixedLengthVector(Op.getSimpleValueType());
     return convertFromScalableVector(Op.getSimpleValueType(),
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ed14fd4539438a..78f99e70c083a7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -961,6 +961,7 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerVPCttzElements(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
                                             unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll b/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll
new file mode 100644
index 00000000000000..77678361932c3a
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll
@@ -0,0 +1,204 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_chec...
[truncated]

@mshockwave
Copy link
Member Author

I wish I could test the expansion logics (i.e. fallback) for this intrinsics but vfirst is in every RVV extensions and AArch64 doesn't really support the combination of VP intrinsics + scalable vectors.

Copy link

github-actions bot commented Apr 29, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff d566a5cd22b4a653f10698f90c691a1452dad5ce a840bd236dec83872b633ee266f96688eaa08d50 -- llvm/include/llvm/CodeGen/TargetLowering.h llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp llvm/lib/Target/RISCV/RISCVISelLowering.cpp llvm/lib/Target/RISCV/RISCVISelLowering.h
View the diff from clang-format here.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1d3b1fcefc..a0c05995ad 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -684,21 +684,48 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
                        MVT::Other, Custom);
 
-    static const unsigned IntegerVPOps[] = {
-        ISD::VP_ADD,         ISD::VP_SUB,         ISD::VP_MUL,
-        ISD::VP_SDIV,        ISD::VP_UDIV,        ISD::VP_SREM,
-        ISD::VP_UREM,        ISD::VP_AND,         ISD::VP_OR,
-        ISD::VP_XOR,         ISD::VP_ASHR,        ISD::VP_LSHR,
-        ISD::VP_SHL,         ISD::VP_REDUCE_ADD,  ISD::VP_REDUCE_AND,
-        ISD::VP_REDUCE_OR,   ISD::VP_REDUCE_XOR,  ISD::VP_REDUCE_SMAX,
-        ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
-        ISD::VP_MERGE,       ISD::VP_SELECT,      ISD::VP_FP_TO_SINT,
-        ISD::VP_FP_TO_UINT,  ISD::VP_SETCC,       ISD::VP_SIGN_EXTEND,
-        ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE,    ISD::VP_SMIN,
-        ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
-        ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
-        ISD::VP_SADDSAT,     ISD::VP_UADDSAT,     ISD::VP_SSUBSAT,
-        ISD::VP_USUBSAT,     ISD::VP_CTTZ_ELTS,   ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
+    static const unsigned IntegerVPOps[] = {ISD::VP_ADD,
+                                            ISD::VP_SUB,
+                                            ISD::VP_MUL,
+                                            ISD::VP_SDIV,
+                                            ISD::VP_UDIV,
+                                            ISD::VP_SREM,
+                                            ISD::VP_UREM,
+                                            ISD::VP_AND,
+                                            ISD::VP_OR,
+                                            ISD::VP_XOR,
+                                            ISD::VP_ASHR,
+                                            ISD::VP_LSHR,
+                                            ISD::VP_SHL,
+                                            ISD::VP_REDUCE_ADD,
+                                            ISD::VP_REDUCE_AND,
+                                            ISD::VP_REDUCE_OR,
+                                            ISD::VP_REDUCE_XOR,
+                                            ISD::VP_REDUCE_SMAX,
+                                            ISD::VP_REDUCE_SMIN,
+                                            ISD::VP_REDUCE_UMAX,
+                                            ISD::VP_REDUCE_UMIN,
+                                            ISD::VP_MERGE,
+                                            ISD::VP_SELECT,
+                                            ISD::VP_FP_TO_SINT,
+                                            ISD::VP_FP_TO_UINT,
+                                            ISD::VP_SETCC,
+                                            ISD::VP_SIGN_EXTEND,
+                                            ISD::VP_ZERO_EXTEND,
+                                            ISD::VP_TRUNCATE,
+                                            ISD::VP_SMIN,
+                                            ISD::VP_SMAX,
+                                            ISD::VP_UMIN,
+                                            ISD::VP_UMAX,
+                                            ISD::VP_ABS,
+                                            ISD::EXPERIMENTAL_VP_REVERSE,
+                                            ISD::EXPERIMENTAL_VP_SPLICE,
+                                            ISD::VP_SADDSAT,
+                                            ISD::VP_UADDSAT,
+                                            ISD::VP_SSUBSAT,
+                                            ISD::VP_USUBSAT,
+                                            ISD::VP_CTTZ_ELTS,
+                                            ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
 
     static const unsigned FloatingPointVPOps[] = {
         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,

@topperc
Copy link
Collaborator

topperc commented Apr 29, 2024

I wish I could test the expansion logics (i.e. fallback) for this intrinsics but vfirst is in every RVV extensions and AArch64 doesn't really support the combination of VP intrinsics + scalable vectors.

Can you test it with fixed vectors on another target?

if (N->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF)
CR = CR.subtract(APInt(64, 1));

unsigned NewResWidth = OrigResVT.getScalarSizeInBits();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a new result width? The result should still be bound by the original EVL.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I somehow mistakenly thought that EVL will also be widen. It's fixed now.

SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
SrcVT.getVectorElementCount());
Source = DAG.getSetCC(DL, SrcVT, Source, AllZero, ISD::SETNE);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a VP_SETCC?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

// %splat = splat /*val=*/VL
// %tz = step_vector
// %v = vp.select %cond, /*true=*/tz, /*false=*/%splat
// %r = vp.reduce.umin %v
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this works. If element 0 and element 1 are both zero and element 2 is non-zero this will return 0, but we want to return 1.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If element 0 and element 1 are both zero and element 2 is non-zero

I thought in this case, we want to return 2 since there are two (i.e. element 0 and 1) zero lowest-numbered elements.

this will return 0, but we want to return 1.

The %cond is a boolean vector converted from the input, if an input element is non-zero, the corresponding result element in %v would be its index value (from %tz, which is a step_vector), which equals to the number of trailing elements. Otherwise, the result element would be EVL (from %splat, which is a splat of EVL). Finally we reduce.umin on %v so we can always find the lowest index value of non-zero elements, hence the number of trailing zero.

Copy link
Collaborator

@topperc topperc Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me look at again. I think I thought about it wrong.

Is there an opportunity to simplify the fall back code in SelectionDAGBuilder for experimental.cttz.elts? It uses umax.

// Convert -1 to VL.
SDValue SetCC =
DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
Res = DAG.getSelect(DL, XLenVT, SetCC, DAG.getZExtOrTrunc(EVL, DL, XLenVT),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think EVL will already be XLenVT at this point.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@michaelmaitland
Copy link
Contributor

This intrinsic is basically ...

Is there anything that makes it basically the vp version and not just the vp version?

@mshockwave
Copy link
Member Author

This intrinsic is basically ...

Is there anything that makes it basically the vp version and not just the vp version?

yeah you're right, I'd removed the word "basically".

SDValue VLo = DAG.getZExtOrTrunc(EVLLo, DL, ResVT);

// if VP_CTTZ_ELTS(Lo) != EVLLo => VP_CTTZ_ELTS(Lo).
// else => EVLLo + VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF(Hi).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of '/' in an arithmetic expression makes this look like division when first reading it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fixed.

@mshockwave
Copy link
Member Author

mshockwave commented Apr 29, 2024

I wish I could test the expansion logics (i.e. fallback) for this intrinsics but vfirst is in every RVV extensions and AArch64 doesn't really support the combination of VP intrinsics + scalable vectors.

Can you test it with fixed vectors on another target?

I tried with something like

define i32 @fixed_v4i32(<4 x i32> %src, <4 x i1> %m, i32 %evl) {
   %r = call i32 @llvm.vp.cttz.elts.i32.v4i32(<4 x i32> %src, i1 0, <4 x i1> %m, i32 %evl)
   ret i32 %r
}

With AArch64 (both with and without SVE). And it failed during type legalization when it tried to promote EVL from i32 to i64, which I thought the legalizer is not suppose to try. I'm not sure if I'm missing any flags

; RV64-NEXT: mv a0, a1
; RV64-NEXT: .LBB0_2:
; RV64-NEXT: ret
%r = call iXLen @llvm.vp.cttz.elts.iXLen.nxv2i1(<vscale x 2 x i1> %src, i1 0, <vscale x 2 x i1> %m, i32 %evl)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add test cases that compare the result to %evl and return an i1. I suspect we'll need to optimize that, but maybe we'll get lucky.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. And I think we're lucky here: the comparison is lowed to checking whether the result of vfirst is less than zero.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that the nxv2i32_cmp_evl test? Isn't the seqz comparing the result of the select to evl?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that the nxv2i32_cmp_evl test? Isn't the seqz comparing the result of the select to evl?

I think you're right.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I sent a PR to fix this: #90538. I'll add more tests tomorrow and open for review.

@topperc
Copy link
Collaborator

topperc commented Apr 29, 2024

I wish I could test the expansion logics (i.e. fallback) for this intrinsics but vfirst is in every RVV extensions and AArch64 doesn't really support the combination of VP intrinsics + scalable vectors.

Can you test it with fixed vectors on another target?

I tried with something like

define i32 @fixed_v4i32(<4 x i32> %src, <4 x i1> %m, i32 %evl) {
   %r = call i32 @llvm.vp.cttz.elts.i32.v4i32(<4 x i32> %src, i1 0, <4 x i1> %m, i32 %evl)
   ret i32 %r
}

With AArch64 (both with and without SVE). And it failed during type legalization when it tried to promote EVL from i32 to i64, which I thought the legalizer is not suppose to try. I'm not sure if I'm missing any flags

Are you sure its failing on the EVL and not the i1 mask?

@@ -1026,6 +1026,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
return;
}
break;
case ISD::VP_CTTZ_ELTS:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should actually be expanded in LegalizeDAG. The original intent of LegalizeVectorOps is to handle instructions that may need to scalarize. Anything that is exclusively vector with no scalar equivalent should be done by LegalizeDAG.

Unfortunately, the code for VP intrinsics in VectorLegalizer::LegalizeOp where we instantiate BEGIN_REGISTER_VP_SDNODE is completely unaware of this. I think we should force Legal for this operation there so that it will get ignored until LegalizeDAG.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@mshockwave
Copy link
Member Author

I wish I could test the expansion logics (i.e. fallback) for this intrinsics but vfirst is in every RVV extensions and AArch64 doesn't really support the combination of VP intrinsics + scalable vectors.

Can you test it with fixed vectors on another target?

I tried with something like

define i32 @fixed_v4i32(<4 x i32> %src, <4 x i1> %m, i32 %evl) {
   %r = call i32 @llvm.vp.cttz.elts.i32.v4i32(<4 x i32> %src, i1 0, <4 x i1> %m, i32 %evl)
   ret i32 %r
}

With AArch64 (both with and without SVE). And it failed during type legalization when it tried to promote EVL from i32 to i64, which I thought the legalizer is not suppose to try. I'm not sure if I'm missing any flags

Are you sure its failing on the EVL and not the i1 mask?

My bad, it fails on the mask operand.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mshockwave
Copy link
Member Author

I'll merge this after #90522

@mshockwave
Copy link
Member Author

Forced push to rebase.

@mshockwave mshockwave merged commit 539f626 into llvm:main Apr 30, 2024
@mshockwave mshockwave deleted the patch/vp-cttz-elts branch April 30, 2024 16:27
mshockwave added a commit that referenced this pull request Apr 30, 2024
Forgot to add vp.cttz.elts into the unittest. Also, I didn't specify the
positions of overloaded type parameters.
@joker-eph
Copy link
Collaborator

This broke the bot: https://lab.llvm.org/buildbot/#/builders/272 ; is there already a fix or should we revert here?

@mshockwave
Copy link
Member Author

mshockwave commented Apr 30, 2024

This broke the bot: https://lab.llvm.org/buildbot/#/builders/272 ; is there already a fix or should we revert here?

It has been fixed in 6ab49fc

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants