[AArch64] Add MATCH loops to LoopIdiomVectorizePass #101976

Merged (7 commits) on Feb 10, 2025

Conversation

rj-jesus (Contributor) commented Aug 5, 2024

This patch adds a new loop idiom to LoopIdiomVectorizePass, enabling it to
recognise and vectorise loops such as:

template<class InputIt, class ForwardIt>
InputIt find_first_of(InputIt first, InputIt last,
                      ForwardIt s_first, ForwardIt s_last)
{
  for (; first != last; ++first)
    for (ForwardIt it = s_first; it != s_last; ++it)
      if (*first == *it)
        return first;
  return last;
}

These loops match the C++ standard library function std::find_first_of.

The loops are vectorised using @experimental.vector.match in #101974.
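As a rough sketch of what the vectorised search reduces to (not code from the patch: the mangled intrinsic name, vector width, and function name below are illustrative assumptions), each block of the first range is tested against the needle vector with a single predicated call:

; Hedged sketch: which active lanes of %search match any byte of %needle?
; Segment size 16 equals the minimum element count, as the AArch64 lowering
; in this patch requires.
define <vscale x 16 x i1> @match_block(<vscale x 16 x i8> %search,
                                       <vscale x 16 x i8> %needle,
                                       <vscale x 16 x i1> %active) {
  %m = call <vscale x 16 x i1> @llvm.experimental.vector.match.nxv16i8(
           <vscale x 16 x i8> %search, <vscale x 16 x i8> %needle,
           <vscale x 16 x i1> %active, i32 16)
  ret <vscale x 16 x i1> %m
}

declare <vscale x 16 x i1> @llvm.experimental.vector.match.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i32 immarg)

A set lane signals a match, from which the pass can recover the pointer to return.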

@llvmbot added the backend:AArch64, vectorizers, llvm:SelectionDAG, llvm:ir, llvm:analysis and llvm:transforms labels on Aug 5, 2024
llvmbot (Member) commented Aug 5, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-selectiondag

Author: Ricardo Jesus (rj-jesus)

Changes

This patch adds a new loop idiom to LoopIdiomVectorizePass, enabling it to
recognise and vectorise loops such as:

template<class InputIt, class ForwardIt>
InputIt find_first_of(InputIt first, InputIt last,
                      ForwardIt s_first, ForwardIt s_last)
{
  for (; first != last; ++first)
    for (ForwardIt it = s_first; it != s_last; ++it)
      if (*first == *it)
        return first;
  return last;
}

These loops match the C++ standard library function std::find_first_of.

The loops are vectorised using @experimental.vector.match in #101974.


Patch is 41.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101976.diff

12 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+45)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+9)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+10)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+46)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+12)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (+441-1)
  • (added) llvm/test/CodeGen/AArch64/find-first-byte.ll (+120)
  • (added) llvm/test/CodeGen/AArch64/intrinsic-vector-match-sve2.ll (+57)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index b17e3c828ed3d..dd9851d1af078 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -19637,6 +19637,51 @@ are undefined.
     }
 
 
+'``llvm.experimental.vector.match.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+This is an overloaded intrinsic. Support for specific vector types is target
+dependent.
+
+::
+
+    declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<n> x <ty>> %op2, <<n> x i1> %mask, i32 <segsize>)
+    declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <vscale x <n> x <ty>> %op2, <vscale x <n> x i1> %mask, i32 <segsize>)
+
+Overview:
+"""""""""
+
+Find elements of the first argument matching any elements of the second.
+
+Arguments:
+""""""""""
+
+The first argument is the search vector, the second argument is the vector of
+elements we are searching for (i.e. for which we consider a match successful),
+and the third argument is a mask that controls which elements of the first
+argument are active. The fourth argument is an immediate that sets the segment
+size for the search window.
+
+Semantics:
+""""""""""
+
+The '``llvm.experimental.vector.match``' intrinsic compares each element in the
+first argument against potentially several elements of the second, placing
+``1`` in the corresponding element of the output vector if any comparison is
+successful, and ``0`` otherwise. Output elements whose corresponding mask
+elements are inactive are set to ``0``. The segment size controls how many
+elements of the second argument each element of the first is compared against.
+
+For example, for vectors with 16 elements, if ``segsize = 16`` then each
+element of the first argument is compared against all 16 elements of the second
+argument; but if ``segsize = 4``, then each of the first four elements of the
+first argument is compared against the first four elements of the second
+argument, each of the second four elements of the first argument is compared
+against the second four elements of the second argument, and so forth.
+
 Matrix Intrinsics
 -----------------
 
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 38e8b9da21397..786c13a177ccf 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1746,6 +1746,10 @@ class TargetTransformInfo {
   bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                              Align Alignment) const;
 
+  /// \returns True if the target supports vector match operations for the
+  /// vector type `VT` using a segment size of `SegSize`.
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const;
+
   struct VPLegalization {
     enum VPTransform {
       // keep the predicating parameter
@@ -2184,6 +2188,7 @@ class TargetTransformInfo::Concept {
   virtual bool supportsScalableVectors() const = 0;
   virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                                      Align Alignment) const = 0;
+  virtual bool hasVectorMatch(VectorType *VT, unsigned SegSize) const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -2952,6 +2957,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const override {
+    return Impl.hasVectorMatch(VT, SegSize);
+  }
+
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index d208a710bb27f..36621861ab8c8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -958,6 +958,8 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const { return false; }
+
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index b4e758136b39f..f6d77aa596f60 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1892,6 +1892,16 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
                              [ IntrArgMemOnly ]>;
 
+// Experimental match
+def int_experimental_vector_match : DefaultAttrsIntrinsic<
+                             [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
+                             [ llvm_anyvector_ty,
+                               LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,  // Mask
+                               llvm_i32_ty ],  // Segment size
+                             [ IntrNoMem, IntrNoSync, IntrWillReturn,
+                               ImmArg<ArgIndex<3>> ]>;
+
 // Operators
 let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
   // Integer arithmetic
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index dcde78925bfa9..d8314af0537fe 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1352,6 +1352,11 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
+                                         unsigned SegSize) const {
+  return TTIImpl->hasVectorMatch(VT, SegSize);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d617c7acd13c..9cb7d65975b9f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8096,6 +8096,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
              DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
     return;
   }
+  case Intrinsic::experimental_vector_match: {
+    auto *VT = dyn_cast<VectorType>(I.getOperand(0)->getType());
+    auto SegmentSize = cast<ConstantInt>(I.getOperand(3))->getLimitedValue();
+    const auto &TTI =
+        TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+    assert(VT && TTI.hasVectorMatch(VT, SegmentSize) && "Unsupported type!");
+    visitTargetIntrinsic(I, Intrinsic);
+    return;
+  }
   case Intrinsic::vector_reverse:
     visitVectorReverse(I);
     return;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7704321a0fc3a..050807142fc0a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6106,6 +6106,51 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
         DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
   }
+  case Intrinsic::experimental_vector_match: {
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+    auto Op1 = Op.getOperand(1);
+    auto Op2 = Op.getOperand(2);
+    auto Mask = Op.getOperand(3);
+    auto SegmentSize =
+        cast<ConstantSDNode>(Op.getOperand(4))->getLimitedValue();
+
+    EVT VT = Op.getValueType();
+    auto MinNumElts = VT.getVectorMinNumElements();
+
+    assert(Op1.getValueType() == Op2.getValueType() && "Type mismatch.");
+    assert(Op1.getValueSizeInBits().getKnownMinValue() == 128 &&
+           "Custom lower only works on 128-bit segments.");
+    assert((Op1.getValueType().getVectorElementType() == MVT::i8  ||
+            Op1.getValueType().getVectorElementType() == MVT::i16) &&
+           "Custom lower only supports 8-bit or 16-bit characters.");
+    assert(SegmentSize == MinNumElts && "Custom lower needs segment size to "
+                                        "match minimum number of elements.");
+
+    if (VT.isScalableVector())
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Mask, Op1, Op2);
+
+    // We can use the SVE2 match instruction to lower this intrinsic by
+    // converting the operands to scalable vectors, doing a match, and then
+    // extracting a fixed-width subvector from the scalable vector.
+
+    EVT OpVT = Op1.getValueType();
+    EVT OpContainerVT = getContainerForFixedLengthVector(DAG, OpVT);
+    EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
+
+    auto ScalableOp1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+    auto ScalableOp2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+    auto ScalableMask = DAG.getNode(ISD::SIGN_EXTEND, dl, OpVT, Mask);
+    ScalableMask = convertFixedMaskToScalableVector(ScalableMask, DAG);
+
+    SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID,
+                                ScalableMask, ScalableOp1, ScalableOp2);
+
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT,
+                       DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+                       DAG.getVectorIdxConstant(0, dl));
+  }
   }
 }
 
@@ -26544,6 +26589,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::experimental_vector_match:
     case Intrinsic::get_active_lane_mask: {
       if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
         return;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index b8f19fa87e2ab..806dc856c5862 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3835,6 +3835,18 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
   }
 }
 
+bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SegSize) const {
+  // Check that the target has SVE2 (and SVE is available), that `VT` is a
+  // legal 128-bit type for MATCH, and that the segment size matches the
+  // minimum number of elements (8 or 16).
+  if (ST->hasSVE2() && ST->isSVEAvailable() &&
+      VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
+      VT->getElementCount().getKnownMinValue() == SegSize &&
+      (VT->getElementCount().getKnownMinValue() ==  8 ||
+       VT->getElementCount().getKnownMinValue() == 16))
+    return true;
+  return false;
+}
+
 InstructionCost
 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
                                        FastMathFlags FMF,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a9189fd53f40b..6ad21a9e0a77a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -391,6 +391,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return ST->hasSVE();
   }
 
+  bool hasVectorMatch(VectorType *VT, unsigned SegSize) const;
+
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                              std::optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index cb31e2a2ecaec..a9683f08c5ab9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -79,6 +79,12 @@ static cl::opt<unsigned>
               cl::desc("The vectorization factor for byte-compare patterns."),
               cl::init(16));
 
+static cl::opt<bool>
+    DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte",
+                         cl::Hidden, cl::init(false),
+                         cl::desc("Proceed with Loop Idiom Vectorize Pass, but "
+                                  "do not convert find-first-byte loop(s)."));
+
 static cl::opt<bool>
     VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
                 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
@@ -136,6 +142,21 @@ class LoopIdiomVectorize {
                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
                             BasicBlock *EndBB);
+
+  bool recognizeFindFirstByte();
+
+  Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
+                             unsigned VF, unsigned CharWidth,
+                             BasicBlock *ExitSucc, BasicBlock *ExitFail,
+                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                             Value *StartA, Value *EndA,
+                             Value *StartB, Value *EndB);
+
+  void transformFindFirstByte(PHINode *IndPhi, unsigned VF, unsigned CharWidth,
+                              BasicBlock *ExitSucc, BasicBlock *ExitFail,
+                              GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                              Value *StartA, Value *EndA,
+                              Value *StartB, Value *EndB);
   /// @}
 };
 } // anonymous namespace
@@ -190,7 +211,13 @@ bool LoopIdiomVectorize::run(Loop *L) {
   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                     << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  if (recognizeByteCompare())
+    return true;
+
+  if (recognizeFindFirstByte())
+    return true;
+
+  return false;
 }
 
 bool LoopIdiomVectorize::recognizeByteCompare() {
@@ -941,3 +968,416 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
       report_fatal_error("Loops must remain in LCSSA form!");
   }
 }
+
+bool LoopIdiomVectorize::recognizeFindFirstByte() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+  if (!TTI->supportsScalableVectors() || DisableFindFirstByte)
+    return false;
+
+  // Define some constants we need throughout.
+  // TODO: Some of these could be made configurable parameters. For example, we
+  // could allow CharWidth = 16 (and VF = 8).
+  unsigned VF = 16;
+  unsigned CharWidth = 8;
+  BasicBlock *Header = CurLoop->getHeader();
+  LLVMContext &Ctx = Header->getContext();
+  auto *CharTy = Type::getIntNTy(Ctx, CharWidth);
+  auto *CharVTy = ScalableVectorType::get(CharTy, VF);
+
+  // Check if the target supports efficient vector matches for vectors of
+  // bytes.
+  if (!TTI->hasVectorMatch(CharVTy, VF))
+    return false;
+
+  // In LoopIdiomVectorize::run we have already checked that the loop has a
+  // preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4)
+    return false;
+
+  // We expect this loop to have one nested loop.
+  if (CurLoop->getSubLoops().size() != 1)
+    return false;
+
+  auto *InnerLoop = CurLoop->getSubLoops().front();
+  PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
+
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
+    return false;
+
+  auto LoopBlocks = CurLoop->getBlocks();
+  // We expect the following blocks. For now, we bail out for anything that
+  // deviates from this.
+  //
+  // .preheader:                                       ; preds = %.preheader.preheader, %23
+  //   %14 = phi ptr [ %24, %23 ], [ %3, %.preheader.preheader ]
+  //   %15 = load i8, ptr %14, align 1, !tbaa !14
+  //   br label %19
+  //
+  // 19:                                               ; preds = %16, %.preheader
+  //   %20 = phi ptr [ %7, %.preheader ], [ %17, %16 ]
+  //   %21 = load i8, ptr %20, align 1, !tbaa !14
+  //   %22 = icmp eq i8 %15, %21
+  //   br i1 %22, label %.loopexit.loopexit, label %16
+  //
+  // 16:                                               ; preds = %19
+  //   %17 = getelementptr inbounds i8, ptr %20, i64 1
+  //   %18 = icmp eq ptr %17, %10
+  //   br i1 %18, label %23, label %19, !llvm.loop !15
+  //
+  // 23:                                               ; preds = %16
+  //   %24 = getelementptr inbounds i8, ptr %14, i64 1
+  //   %25 = icmp eq ptr %24, %6
+  //   br i1 %25, label %.loopexit.loopexit5, label %.preheader, !llvm.loop !17
+  //
+  if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[1]->sizeWithoutDebug() > 4 ||
+      LoopBlocks[2]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[3]->sizeWithoutDebug() > 3)
+    return false;
+
+  // If we match the pattern, IndPhi is going to be replaced. We cannot replace
+  // the loop if any of its other instructions are used outside of it.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != IndPhi)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
+
+  // Match the branch instruction for the header. We are expecting an
+  // unconditional branch to the inner loop.
+  BasicBlock *MatchBB;
+  if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
+      !InnerLoop->contains(MatchBB))
+    return false;
+
+  // MatchBB should be the entrypoint into the inner loop containing the
+  // comparison between a search item and a valid/successful match.
+  ICmpInst::Predicate MatchPred;
+  BasicBlock *ExitSucc;
+  BasicBlock *InnerBB;
+  Value *LoadA, *LoadB;
+  if (!match(MatchBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ ||
+      !InnerLoop->contains(InnerBB))
+    return false;
+
+  // We expect a single use of IndPhi outside of CurLoop. The outside use
+  // should be a PHINode in ExitSucc coming from MatchBB.
+  // Note: Strictly speaking we are not checking for a *single* use of IndPhi
+  // outside of CurLoop here, but below we check that we only exit CurLoop to
+  // ExitSucc in one place, so by construction this should be true. Besides, in
+  // the event it is not, as long as the use is a PHINode in ExitSucc and comes
+  // from MatchBB, the transformation should still be valid in any case.
+  for (Use &U : IndPhi->uses())
+    if (CurLoop->contains(cast<Instruction>(U.getUser())))
+      continue;
+    else if (auto *PN = dyn_cast<PHINode>(U.getUser());
+             !PN || PN->getParent() != ExitSucc ||
+             PN->getIncomingBlock(U) != MatchBB)
+      return false;
+
+  // Match the loads.
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
+
+  // Make sure they are simple.
+  LoadInst *LoadAI = cast<LoadInst>(LoadA);
+  LoadInst *LoadBI = cast<LoadInst>(LoadB);
+  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+    return false;
+
+  // The values loaded come from two PHIs that can only have two incoming
+  // values.
+  PHINode *PNA = dyn_cast<PHINode>(A);
+  PHINode *PNB = dyn_cast<PHINode>(B);
+  if (!PNA || PNA->getNumIncomingValues() != 2 ||
+      !PNB || PNB->getNumIncomingValues() != 2)
+    return false;
+
+  // One PHI comes from the outer loop, the other one from the inner loop.
+  // CurLoop contains PNA, InnerLoop PNB.
+  if (InnerLoop->contains(PNA))
+    std::swap(PNA, PNB);
+  if (PNA != &Header->front() || PNB != &MatchBB->front())
+    return false;
+
+  // The incoming values of both PHI nodes should be a gep of 1.
+  Value *StartA = PNA->getIncomingValue(0);
+  Value *IndexA = PNA->getIncomingValue(1);
+  if (CurLoop->contains(PNA->getIncomingBlock(0)))
+    std::swap(StartA, IndexA);...
[truncated]
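To make the segment-size semantics described in the LangRef text above concrete, consider this fixed-width sketch (hypothetical: the function name and the choice of segsize = 4 are illustrative only, and note the AArch64 custom lowering in this patch only handles a segment size equal to the element count):

; With segsize = 4 on 16-element vectors, lanes 0-3 of %op1 are each compared
; against lanes 0-3 of %op2, lanes 4-7 against lanes 4-7, and so forth.
define <16 x i1> @segmented_match(<16 x i8> %op1, <16 x i8> %op2,
                                  <16 x i1> %mask) {
  %r = call <16 x i1> @llvm.experimental.vector.match.v16i8(
           <16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask, i32 4)
  ret <16 x i1> %r
}

declare <16 x i1> @llvm.experimental.vector.match.v16i8(<16 x i8>, <16 x i8>, <16 x i1>, i32 immarg)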

llvmbot (Member) commented Aug 5, 2024

@llvm/pr-subscribers-llvm-transforms

llvmbot (Member) commented Aug 5, 2024

@llvm/pr-subscribers-backend-aarch64
+
+  bool recognizeFindFirstByte();
+
+  Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
+                             unsigned VF, unsigned CharWidth,
+                             BasicBlock *ExitSucc, BasicBlock *ExitFail,
+                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                             Value *StartA, Value *EndA,
+                             Value *StartB, Value *EndB);
+
+  void transformFindFirstByte(PHINode *IndPhi, unsigned VF, unsigned CharWidth,
+                              BasicBlock *ExitSucc, BasicBlock *ExitFail,
+                              GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                              Value *StartA, Value *EndA,
+                              Value *StartB, Value *EndB);
   /// @}
 };
 } // anonymous namespace
@@ -190,7 +211,13 @@ bool LoopIdiomVectorize::run(Loop *L) {
   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                     << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  if (recognizeByteCompare())
+    return true;
+
+  if (recognizeFindFirstByte())
+    return true;
+
+  return false;
 }
 
 bool LoopIdiomVectorize::recognizeByteCompare() {
@@ -941,3 +968,416 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
       report_fatal_error("Loops must remain in LCSSA form!");
   }
 }
+
+bool LoopIdiomVectorize::recognizeFindFirstByte() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+  if (!TTI->supportsScalableVectors() || DisableFindFirstByte)
+    return false;
+
+  // Define some constants we need throughout.
+  // TODO: Some of these could be made configurable parameters. For example, we
+  // could allow CharWidth = 16 (and VF = 8).
+  unsigned VF = 16;
+  unsigned CharWidth = 8;
+  BasicBlock *Header = CurLoop->getHeader();
+  LLVMContext &Ctx = Header->getContext();
+  auto *CharTy = Type::getIntNTy(Ctx, CharWidth);
+  auto *CharVTy = ScalableVectorType::get(CharTy, VF);
+
+  // Check if the target supports efficient vector matches for vectors of
+  // bytes.
+  if (!TTI->hasVectorMatch(CharVTy, VF))
+    return false;
+
+  // In LoopIdiomVectorize::run we have already checked that the loop has a
+  // preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4)
+    return false;
+
+  // We expect this loop to have one nested loop.
+  if (CurLoop->getSubLoops().size() != 1)
+    return false;
+
+  auto *InnerLoop = CurLoop->getSubLoops().front();
+  PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
+
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
+    return false;
+
+  auto LoopBlocks = CurLoop->getBlocks();
+  // We are expecting the following blocks below. For now, we will bail out for
+  // anything deviating from this.
+  //
+  // .preheader:                                       ; preds = %.preheader.preheader, %23
+  //   %14 = phi ptr [ %24, %23 ], [ %3, %.preheader.preheader ]
+  //   %15 = load i8, ptr %14, align 1, !tbaa !14
+  //   br label %19
+  //
+  // 19:                                               ; preds = %16, %.preheader
+  //   %20 = phi ptr [ %7, %.preheader ], [ %17, %16 ]
+  //   %21 = load i8, ptr %20, align 1, !tbaa !14
+  //   %22 = icmp eq i8 %15, %21
+  //   br i1 %22, label %.loopexit.loopexit, label %16
+  //
+  // 16:                                               ; preds = %19
+  //   %17 = getelementptr inbounds i8, ptr %20, i64 1
+  //   %18 = icmp eq ptr %17, %10
+  //   br i1 %18, label %23, label %19, !llvm.loop !15
+  //
+  // 23:                                               ; preds = %16
+  //   %24 = getelementptr inbounds i8, ptr %14, i64 1
+  //   %25 = icmp eq ptr %24, %6
+  //   br i1 %25, label %.loopexit.loopexit5, label %.preheader, !llvm.loop !17
+  //
+  if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[1]->sizeWithoutDebug() > 4 ||
+      LoopBlocks[2]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[3]->sizeWithoutDebug() > 3)
+    return false;
+
+  // If we match the pattern, IndPhi is going to be replaced. We cannot replace
+  // the loop if any other of its instructions are used outside of it.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != IndPhi)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
+
+  // Match the branch instruction for the header. We are expecting an
+  // unconditional branch to the inner loop.
+  BasicBlock *MatchBB;
+  if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
+      !InnerLoop->contains(MatchBB))
+    return false;
+
+  // MatchBB should be the entrypoint into the inner loop containing the
+  // comparison between a search item and a valid/successful match.
+  ICmpInst::Predicate MatchPred;
+  BasicBlock *ExitSucc;
+  BasicBlock *InnerBB;
+  Value *LoadA, *LoadB;
+  if (!match(MatchBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ ||
+      !InnerLoop->contains(InnerBB))
+    return false;
+
+  // We expect a single use of IndPhi outside of CurLoop. The outside use
+  // should be a PHINode in ExitSucc coming from MatchBB.
+  // Note: Strictly speaking we are not checking for a *single* use of IndPhi
+  // outside of CurLoop here, but below we check that we only exit CurLoop to
+  // ExitSucc in one place, so by construction this should be true. Besides, in
+  // the event it is not, as long as the use is a PHINode in ExitSucc and comes
+  // from MatchBB, the transformation should still be valid in any case.
+  for (Use &U : IndPhi->uses())
+    if (CurLoop->contains(cast<Instruction>(U.getUser())))
+      continue;
+    else if (auto *PN = dyn_cast<PHINode>(U.getUser());
+             !PN || PN->getParent() != ExitSucc ||
+             PN->getIncomingBlock(U) != MatchBB)
+      return false;
+
+  // Match the loads.
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
+
+  // Make sure they are simple.
+  LoadInst *LoadAI = cast<LoadInst>(LoadA);
+  LoadInst *LoadBI = cast<LoadInst>(LoadB);
+  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+    return false;
+
+  // The values loaded come from two PHIs that can only have two incoming
+  // values.
+  PHINode *PNA = dyn_cast<PHINode>(A);
+  PHINode *PNB = dyn_cast<PHINode>(B);
+  if (!PNA || PNA->getNumIncomingValues() != 2 ||
+      !PNB || PNB->getNumIncomingValues() != 2)
+    return false;
+
+  // One PHI comes from the outer loop, the other one from the inner loop.
+  // CurLoop contains PNA, InnerLoop PNB.
+  if (InnerLoop->contains(PNA))
+    std::swap(PNA, PNB);
+  if (PNA != &Header->front() || PNB != &MatchBB->front())
+    return false;
+
+  // The incoming values of both PHI nodes should be a gep of 1.
+  Value *StartA = PNA->getIncomingValue(0);
+  Value *IndexA = PNA->getIncomingValue(1);
+  if (CurLoop->contains(PNA->getIncomingBlock(0)))
+    std::swap(StartA, IndexA);...
[truncated]

github-actions bot commented Aug 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@rj-jesus
Contributor Author

Hi @david-arm and @paulwalker-arm, I've rebased this patch to use the version of @llvm.experimental.vector.match committed last week. I've also added a simple cost model which currently only lets the transformation kick in for supported SVE vectors.

Currently, a simple nested find loop like the one in the description should compile to something like:

find_first_of_i8:                       // @find_first_of_i8
	.cfi_startproc
// %bb.0:
	cmp	x0, x1
	b.eq	.LBB0_7
// %bb.1:
	cmp	x2, x3
	b.eq	.LBB0_7
// %bb.2:                               // %.preheader
	ptrue	p0.b, vl16
.LBB0_3:                                // =>This Loop Header: Depth=1
                                        //     Child Loop BB0_4 Depth 2
	whilelo	p1.b, x0, x1
	mov	x8, x2
	mov	x9, x2
	and	p1.b, p0/z, p0.b, p1.b
	ld1b	{ z0.b }, p1/z, [x0]
.LBB0_4:                                //   Parent Loop BB0_3 Depth=1
                                        // =>  This Inner Loop Header: Depth=2
	whilelo	p2.b, x8, x3
	and	p2.b, p0/z, p0.b, p2.b
	ld1b	{ z1.b }, p2/z, [x9]
	mov	z2.b, b1
	sel	z1.b, p2, z1.b, z2.b
	mov	z1.q, q1
	match	p2.b, p1/z, z0.b, z1.b
	b.ne	.LBB0_8
// %bb.5:                               //   in Loop: Header=BB0_4 Depth=2
	add	x9, x9, #16
	add	x8, x8, #16
	cmp	x9, x3
	b.lo	.LBB0_4
// %bb.6:                               //   in Loop: Header=BB0_3 Depth=1
	add	x0, x0, #16
	cmp	x0, x1
	b.lo	.LBB0_3
.LBB0_7:                                // %.loopexit1
	mov	x0, x1
	ret
.LBB0_8:                                // %.loopexit
	ptrue	p0.b
	brkb	p0.b, p0/z, p2.b
	incp	x0, p0.b
	ret
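
At the IR level, the transformation boils down to emitting a call to the new intrinsic inside the vector loop. As a rough illustration (a hand-written sketch with assumed names and context, not an excerpt from this patch), such a call can be created with IRBuilder like so:

```cpp
// Sketch only: emit @llvm.experimental.vector.match with IRBuilder.
// `Builder`, `Search`, `Needle` and `Mask` are assumed to exist already;
// the overload is taken to be the data vector type, and SegSize must be an
// immediate (16 for <vscale x 16 x i8> segments).
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

static Value *emitVectorMatch(IRBuilder<> &Builder, Value *Search,
                              Value *Needle, Value *Mask, unsigned SegSize) {
  return Builder.CreateIntrinsic(Intrinsic::experimental_vector_match,
                                 {Search->getType()},
                                 {Search, Needle, Mask,
                                  Builder.getInt32(SegSize)});
}
```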

Could you please let me know if you have any thoughts or suggestions? Many thanks in advance!

@sjoerdmeijer
Collaborator

The output looks cool. Naive drive-by question: is it faster than what we currently generate? :)

@rj-jesus
Contributor Author

rj-jesus commented Nov 21, 2024

It is. In the table below I've summarised the speedups obtained on a Neoverse V2 for combinations of search (columns) and needle (rows) array sizes for the loop in the PR's description:
(Speedup matrix attached as an image: columns are search-array sizes, rows are needle-array sizes; values are speedups over the scalar loop.)
For item (4, 80) in the matrix, for example, we have a speedup of 16.1x, which means that for search arrays of 80 characters and needle arrays of 4 characters, where the match happens at the last element of both arrays, we get an average speedup of 16.1x over the current scalar loops. I can probably share the benchmark I used to collect these numbers if you think that would be useful; I would just need to double-check this internally.

Overall the vectorised loops are at least competitive with the scalar loops (and often obtain substantial speedups over the latter). The exceptions are for degenerate cases of search or needle arrays consisting of a single element (top left corner of the matrix).

The output we currently generate for the vectorised loops also contains four "unnecessary instructions", in particular:

  • mov x8, x2 - redundant copy of mov x9, x2
  • add x8, x8, #16 - as above, copy of add x9, x9, #16
  • mov z1.q, q1 - we do this for correctness to broadcast the needles from a segment to a full SVE vector, but this is not necessary here since we effectively only work with the first 128-bit segment of the SVE registers
  • ptrue p0.b - this is introduced by experimental.cttz.elts but is technically not needed in this case since we could reuse the previously defined p0

None of these issues is related to the transformation itself (and they don't affect its performance that much anyway), so I think they can be addressed separately.

@sjoerdmeijer
Collaborator

Thanks for the experiments and data, I think that looks really good! This always looks like a win, except when the size is 1 or 2. That case is most likely so exceptional that we don't need runtime size checks to choose between the scalar and the vector loop.

@rj-jesus
Contributor Author

Hi @sjoerdmeijer, thanks very much for the feedback! I'll update the patch later this week.

@rj-jesus
Contributor Author

Thanks again for your feedback @sjoerdmeijer, I believe I've addressed most of your comments in my latest commit.
Please let me know if I've missed anything or if you have any other comments!

Collaborator

@sjoerdmeijer sjoerdmeijer left a comment


This looks good to me.

@rj-jesus rj-jesus force-pushed the rjj/aarch64-find-first-byte branch from a3292d1 to 1b02d46 on December 11, 2024 at 11:41
@rj-jesus
Contributor Author

rj-jesus commented Dec 11, 2024

This looks good to me.

Many thanks, @sjoerdmeijer! I've just made the changes you suggested and rebased this on 2fe30bc, please let me know if you have any other comments.

Also, if anyone else is planning to have a look at this, could you please give me a shout?
Otherwise I'll merge this in the next couple of days.

Thanks again!

@david-arm
Contributor

The memory checks for the inner loop only need to be done once, I assume, since I'd expect the start value for the needle to be loop invariant? If not, you'll need to calculate the entire needle range for the whole outer loop. I'd expect a minor hit to performance, but given the fantastic improvements you're reporting I'd still expect it to be a nice win. :) For the mismatch case the trip counts were typically quite small, and the memory checks reduced the performance gain from around 8% to 7% or something like that.

@rj-jesus
Contributor Author

The patch is looking mostly good. I've still got a bit more to review, but at the moment there is a serious bug with this so can you hold off landing this until it's fixed?

Of course I will; thank you very much for the heads up!

The inner loop can exit early before you've loaded the rest of the needle, so it needs memory checks before entering the vector loop to ensure the needle doesn't cross a page boundary.

In fact, I think we need page boundary checks for the pointer in the outer loop as well. Suppose you find a match in your needle after only two iterations of your scalar outer loop. Your third iteration could cross a page boundary and seg fault. The scalar version of the loop will be fine because we exited before the fault, but the vector version will crash.

I think I'm missing something. We should never be reading beyond the end of the search or needle arrays (if we do, that's definitely a bug), so I don't think we would be prone to segfaulting. The idea was to use llvm.get.active.lane.mask to create a valid predicate for the loads; shouldn't that work?
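
Roughly what I had in mind, as a hand-written sketch (illustrative names, not code from the patch):

```cpp
// Sketch: bound each load by the array end via get.active.lane.mask, so we
// never read outside the [start, end) range of the loop. `Builder`, `Idx`,
// `End`, `Ptr` and `CharVTy` are assumed; names are illustrative.
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Alignment.h"

using namespace llvm;

static Value *emitBoundedLoad(IRBuilder<> &Builder, Value *Idx, Value *End,
                              Value *Ptr, ScalableVectorType *CharVTy) {
  // Lane i is active iff Idx + i < End.
  auto *MaskTy = ScalableVectorType::get(Builder.getInt1Ty(),
                                         CharVTy->getMinNumElements());
  Value *Mask = Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                                        {MaskTy, Idx->getType()}, {Idx, End});
  // Inactive lanes yield the zero passthru instead of touching memory.
  return Builder.CreateMaskedLoad(CharVTy, Ptr, Align(1), Mask,
                                  Constant::getNullValue(CharVTy));
}
```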

@david-arm
Contributor

I think I'm missing something. We should never be reading beyond the end of the search or needle arrays (if we do, that's definitely a bug), so I don't think we would be prone to segfaulting. The idea was to use llvm.get.active.lane.mask to create a valid predicate for the loads; shouldn't that work?

Your predicate simply ensures you never load more data than the range of the loop, but it doesn't protect you against seg faults. It is possible to use masks to avoid seg faults, but you'd have to split your loops up page by page, which is pretty complicated. Consider one of the example C codes in your test:

;   char* find_first_of(char *first, char *last, char *s_first, char *s_last) {
;     for (; first != last; ++first)
;       for (char *it = s_first; it != s_last; ++it)
;         if (*first == *it)
;           return first;
;     return last;
;   }

The scalar loop may find a match if (*first == *it) on the first iteration of your inner loop. Suppose the pointer it starts at the last byte of a page and (s_last - s_first) = 16. In the scalar version you'll exit the loop before crossing the page boundary, so you won't seg fault, but in the vector version you'll load all 16 bytes, which crosses the page boundary. If the second page isn't mapped in memory, the vector version will seg fault. This is the fundamental problem with vectorising loops that contain early exits.
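
To make the failure mode concrete, here is a small standalone (POSIX-only) reproducer of the hazard, independent of this patch:

```cpp
// Demo of the early-exit hazard: a 16-byte needle whose first byte is the
// last byte of a mapped page, with the next page made inaccessible. The
// scalar loop exits on the first match and never touches the second page;
// an unconditional 16-byte load does, and seg faults.
#include <cstdio>
#include <cstring>
#include <sys/mman.h>
#include <unistd.h>

int main() {
  long PageSize = sysconf(_SC_PAGESIZE);
  // Two anonymous pages; make the second one inaccessible.
  char *Buf = static_cast<char *>(mmap(nullptr, 2 * PageSize,
                                       PROT_READ | PROT_WRITE,
                                       MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
  mprotect(Buf + PageSize, PageSize, PROT_NONE);

  char *Needle = Buf + PageSize - 1; // last mapped byte
  *Needle = 'a';

  // Scalar inner loop: matches on the first iteration, exits, no fault.
  for (char *It = Needle; It != Needle + 16; ++It)
    if (*It == 'a') {
      std::puts("scalar: found");
      break;
    }

  // "Vector" version: reading all 16 bytes at once crosses into the
  // PROT_NONE page and seg faults, although the scalar loop was fine.
  char Tmp[16];
  std::memcpy(Tmp, Needle, 16); // crashes here
  return Tmp[0];
}
```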

@david-arm
Contributor

Alternatively, if you can prove at compile time that the vector loop won't fault, you can handle that as a first step, then have a follow-on patch that enables loop versioning. For example, if you have some IR where the pointers all come from an alloca and you know the trip counts for both loops, you can use llvm::isDereferenceableAndAlignedInLoop to ask whether the load pointer can be unconditionally dereferenced.
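
A rough sketch of that compile-time check (hand-written; the helper lives in llvm/Analysis/Loads.h, but treat the exact parameter list as an assumption to verify against your LLVM revision):

```cpp
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Dominators.h"

using namespace llvm;

// Sketch: only transform when both loads are provably dereferenceable on
// every iteration of their loops (e.g. pointers derived from an alloca
// with known trip counts), so no runtime page checks are needed.
static bool safeWithoutVersioning(LoadInst *SearchLoad, LoadInst *NeedleLoad,
                                  Loop *Outer, Loop *Inner,
                                  ScalarEvolution &SE, DominatorTree &DT,
                                  AssumptionCache &AC) {
  return isDereferenceableAndAlignedInLoop(SearchLoad, Outer, SE, DT, &AC) &&
         isDereferenceableAndAlignedInLoop(NeedleLoad, Inner, SE, DT, &AC);
}
```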

@rj-jesus
Contributor Author

Your predicate simply ensures you never load more data than the range of the loop, but it doesn't protect you against seg faults.

Ah, of course, I see what you mean now, thanks very much for catching this! I'll add the guards and incorporate your other comments. Thank you again!

@rj-jesus rj-jesus force-pushed the rjj/aarch64-find-first-byte branch from 1b02d46 to ca1a373 on January 2, 2025 at 12:36
@rj-jesus
Contributor Author

rj-jesus commented Jan 2, 2025

Many thanks for your comments, @david-arm. I should have addressed them in the latest revision of the patch, and I've added the runtime page boundary checks you pointed out. Could you please let me know if you have any other suggestions?

@rj-jesus
Contributor Author

Ping :)

@appujee
Contributor

appujee commented Jan 28, 2025

@david-arm do you have any pending comments, or can we merge it?

@alexey-bataev alexey-bataev requested a review from arcbbb on January 29, 2025 at 17:13
This patch adds a new loop to LoopIdiomVectorizePass, enabling it to
recognise and use @llvm.experimental.vector.match to vectorise loops
such as:

    char* find_first_of(char *first, char *last,
                        char *s_first, char *s_last) {
      for (; first != last; ++first)
        for (char *it = s_first; it != s_last; ++it)
          if (*first == *it)
            return first;
      return last;
    }

These loops match the C++ standard library's std::find_first_of.
@rj-jesus rj-jesus force-pushed the rjj/aarch64-find-first-byte branch from ca1a373 to db2fbc6 on February 5, 2025 at 14:13
@rj-jesus
Contributor Author

rj-jesus commented Feb 5, 2025

Many thanks for the comments, @david-arm. I've tidied up the tests and added the checks for loop invariance (though I think they are redundant given how the IR is matched beforehand, but they don't hurt). Please let me know if that's what you had in mind and if you have any other comments. :)

Contributor

@david-arm david-arm left a comment


This looks great - thanks for being patient with all the review comments! I have a few more minor comments, mostly around adding a couple more tests, but then I think it's ready to go!

@rj-jesus
Contributor Author

rj-jesus commented Feb 6, 2025

Thanks very much for the feedback, @david-arm! I've addressed most of your comments; please let me know if you'd like me to add anything else!

Contributor

@david-arm david-arm left a comment


LGTM!

@rj-jesus
Contributor Author

rj-jesus commented Feb 6, 2025

Many thanks for the review! I'll let the tests run and merge it afterwards. 😄

@hiraditya
Collaborator

Thanks for working on this. I wonder what other algorithms can benefit from the SVE2 MATCH instruction.

@rj-jesus
Contributor Author

rj-jesus commented Feb 7, 2025

Hi @hiraditya, I think it can be useful for parsers and some compression algorithms. Do you have any particular use cases in mind?
(I see the tests are green, but I think I'd rather wait until next Monday to merge this in case any issues pop up!)
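
As a concrete illustration of the parser case (plain C++, nothing specific to this patch), a delimiter scan like the one below is exactly the loop shape this pass can target once std::find_first_of is inlined:

```cpp
#include <algorithm>
#include <cstdio>
#include <string_view>

// Tokenizer-style hot loop: find the first of several delimiters. After
// inlining, std::find_first_of becomes the nested search/needle loop that
// LoopIdiomVectorize recognises and can lower to MATCH on SVE2.
int main() {
  std::string_view Text = "key=value; other";
  constexpr std::string_view Delims = "=;, ";
  auto It = std::find_first_of(Text.begin(), Text.end(),
                               Delims.begin(), Delims.end());
  std::printf("first delimiter at offset %td\n", It - Text.begin());
  return 0;
}
```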

@rj-jesus rj-jesus merged commit 5f84b6e into llvm:main Feb 10, 2025
8 checks passed
@hiraditya
Collaborator

Hi @hiraditya, I think it can be useful for parsers and some compression algorithms. Do you have any particular use cases in mind? (I see the tests are green, but I think I'd rather wait until next Monday to merge this in case any issues pop up!)

No, I was just curious. Thanks.

Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
This patch adds a new loop to LoopIdiomVectorizePass, enabling it to
recognise and vectorise loops such as:
```cpp
template<class InputIt, class ForwardIt>
InputIt find_first_of(InputIt first, InputIt last,
                      ForwardIt s_first, ForwardIt s_last)
{
  for (; first != last; ++first)
    for (ForwardIt it = s_first; it != s_last; ++it)
      if (*first == *it)
        return first;
  return last;
}
```

These loops match the C++ standard library function `std::find_first_of`.