Skip to content

Commit e9bd6d4

Browse files
committed
[AArch64] Add @llvm.experimental.vector.match
This patch introduces an experimental intrinsic for matching the elements of one vector against the elements of another. For AArch64 targets that support SVE2, it lowers to a MATCH instruction for supported fixed and scalable vector types. Otherwise, the intrinsic has generic lowering in SelectionDAGBuilder.
1 parent 4614b80 commit e9bd6d4

File tree

11 files changed

+455
-0
lines changed

11 files changed

+455
-0
lines changed

llvm/docs/LangRef.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20043,6 +20043,45 @@ are undefined.
2004320043
}
2004420044

2004520045

20046+
'``llvm.experimental.vector.match.*``' Intrinsic
20047+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20048+
20049+
Syntax:
20050+
"""""""
20051+
20052+
This is an overloaded intrinsic. Support for specific vector types is target
20053+
dependent.
20054+
20055+
::
20056+
20057+
declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<m> x <ty>> %op2, <<n> x i1> %mask)
20058+
declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <<m> x <ty>> %op2, <vscale x <n> x i1> %mask)
20059+
20060+
Overview:
20061+
"""""""""
20062+
20063+
Find active elements of the first argument matching any elements of the second.
20064+
20065+
Arguments:
20066+
""""""""""
20067+
20068+
The first argument is the search vector, the second argument the vector of
20069+
elements we are searching for (i.e. for which we consider a match successful),
20070+
and the third argument is a mask that controls which elements of the first
20071+
argument are active.
20072+
20073+
Semantics:
20074+
""""""""""
20075+
20076+
The '``llvm.experimental.vector.match``' intrinsic compares each active element
20077+
in the first argument against the elements of the second argument, placing
20078+
``1`` in the corresponding element of the output vector if any comparison is
20079+
successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
20080+
in the output.
20081+
20082+
The second argument needs to be a fixed-length vector with the same element
20083+
type as the first argument.
20084+
2004620085
Matrix Intrinsics
2004720086
-----------------
2004820087

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,6 +1771,11 @@ class TargetTransformInfo {
17711771
/// This should also apply to lowering for vector funnel shifts (rotates).
17721772
bool isVectorShiftByScalarCheap(Type *Ty) const;
17731773

1774+
/// \returns True if the target has hardware support for vector match
1775+
/// operations between vectors of type `VT` and search vectors of `SearchSize`
1776+
/// elements, and false otherwise.
1777+
bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
1778+
17741779
struct VPLegalization {
17751780
enum VPTransform {
17761781
// keep the predicating parameter
@@ -2221,6 +2226,7 @@ class TargetTransformInfo::Concept {
22212226
SmallVectorImpl<Use *> &OpsToSink) const = 0;
22222227

22232228
virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
2229+
virtual bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const = 0;
22242230
virtual VPLegalization
22252231
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
22262232
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3014,6 +3020,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
30143020
return Impl.isVectorShiftByScalarCheap(Ty);
30153021
}
30163022

3023+
bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const override {
3024+
return Impl.hasVectorMatch(VT, SearchSize);
3025+
}
3026+
30173027
VPLegalization
30183028
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
30193029
return Impl.getVPLegalizationStrategy(PI);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,10 @@ class TargetTransformInfoImplBase {
995995

996996
bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
997997

998+
bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
999+
return false;
1000+
}
1001+
9981002
TargetTransformInfo::VPLegalization
9991003
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
10001004
return TargetTransformInfo::VPLegalization(

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,14 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
19181918
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
19191919
[ IntrArgMemOnly ]>;
19201920

1921+
// Experimental match
1922+
def int_experimental_vector_match : DefaultAttrsIntrinsic<
1923+
[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
1924+
[ llvm_anyvector_ty,
1925+
llvm_anyvector_ty,
1926+
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], // Mask
1927+
[ IntrNoMem, IntrNoSync, IntrWillReturn ]>;
1928+
19211929
// Operators
19221930
let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
19231931
// Integer arithmetic

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,11 @@ bool TargetTransformInfo::isVectorShiftByScalarCheap(Type *Ty) const {
13831383
return TTIImpl->isVectorShiftByScalarCheap(Ty);
13841384
}
13851385

1386+
bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
1387+
unsigned SearchSize) const {
1388+
return TTIImpl->hasVectorMatch(VT, SearchSize);
1389+
}
1390+
13861391
TargetTransformInfo::Concept::~Concept() = default;
13871392

13881393
TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8156,6 +8156,42 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81568156
DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
81578157
return;
81588158
}
8159+
case Intrinsic::experimental_vector_match: {
8160+
SDValue Op1 = getValue(I.getOperand(0));
8161+
SDValue Op2 = getValue(I.getOperand(1));
8162+
SDValue Mask = getValue(I.getOperand(2));
8163+
EVT Op1VT = Op1.getValueType();
8164+
EVT Op2VT = Op2.getValueType();
8165+
EVT ResVT = Mask.getValueType();
8166+
unsigned SearchSize = Op2VT.getVectorNumElements();
8167+
8168+
LLVMContext &Ctx = *DAG.getContext();
8169+
const auto &TTI =
8170+
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
8171+
8172+
// If the target has native support for this vector match operation, lower
8173+
// the intrinsic directly; otherwise, lower it below.
8174+
if (TTI.hasVectorMatch(cast<VectorType>(Op1VT.getTypeForEVT(Ctx)),
8175+
SearchSize)) {
8176+
visitTargetIntrinsic(I, Intrinsic);
8177+
return;
8178+
}
8179+
8180+
SDValue Ret = DAG.getNode(ISD::SPLAT_VECTOR, sdl, ResVT,
8181+
DAG.getConstant(0, sdl, MVT::i1));
8182+
8183+
for (unsigned i = 0; i < SearchSize; ++i) {
8184+
SDValue Op2Elem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
8185+
Op2VT.getVectorElementType(), Op2,
8186+
DAG.getVectorIdxConstant(i, sdl));
8187+
SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, sdl, Op1VT, Op2Elem);
8188+
SDValue Cmp = DAG.getSetCC(sdl, ResVT, Op1, Splat, ISD::SETEQ);
8189+
Ret = DAG.getNode(ISD::OR, sdl, ResVT, Ret, Cmp);
8190+
}
8191+
8192+
setValue(&I, DAG.getNode(ISD::AND, sdl, ResVT, Ret, Mask));
8193+
return;
8194+
}
81598195
case Intrinsic::vector_reverse:
81608196
visitVectorReverse(I);
81618197
return;

llvm/lib/IR/Verifier.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6154,6 +6154,27 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
61546154
&Call);
61556155
break;
61566156
}
6157+
case Intrinsic::experimental_vector_match: {
6158+
Value *Op1 = Call.getArgOperand(0);
6159+
Value *Op2 = Call.getArgOperand(1);
6160+
Value *Mask = Call.getArgOperand(2);
6161+
6162+
VectorType *Op1Ty = dyn_cast<VectorType>(Op1->getType());
6163+
VectorType *Op2Ty = dyn_cast<VectorType>(Op2->getType());
6164+
VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());
6165+
6166+
Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
6167+
Check(!isa<ScalableVectorType>(Op2Ty), "Second operand cannot be scalable.",
6168+
&Call);
6169+
Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
6170+
"First two operands must have the same element type.", &Call);
6171+
Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
6172+
"First operand and mask must have the same number of elements.",
6173+
&Call);
6174+
Check(MaskTy->getElementType()->isIntegerTy(1),
6175+
"Mask must be a vector of i1's.", &Call);
6176+
break;
6177+
}
61576178
case Intrinsic::vector_insert: {
61586179
Value *Vec = Call.getArgOperand(0);
61596180
Value *SubVec = Call.getArgOperand(1);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6364,6 +6364,58 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
63646364
DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
63656365
return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
63666366
}
6367+
case Intrinsic::experimental_vector_match: {
6368+
SDValue ID =
6369+
DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
6370+
6371+
auto Op1 = Op.getOperand(1);
6372+
auto Op2 = Op.getOperand(2);
6373+
auto Mask = Op.getOperand(3);
6374+
6375+
EVT Op1VT = Op1.getValueType();
6376+
EVT Op2VT = Op2.getValueType();
6377+
EVT ResVT = Op.getValueType();
6378+
6379+
assert((Op1VT.getVectorElementType() == MVT::i8 ||
6380+
Op1VT.getVectorElementType() == MVT::i16) &&
6381+
"Expected 8-bit or 16-bit characters.");
6382+
assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
6383+
assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
6384+
"Operand type mismatch.");
6385+
assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
6386+
"Invalid operands.");
6387+
6388+
// Wrap the search vector in a scalable vector.
6389+
EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
6390+
Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
6391+
6392+
// If the result is scalable, we need to broadcast the search vector across
6393+
// the SVE register and then carry out the MATCH.
6394+
if (ResVT.isScalableVector()) {
6395+
Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
6396+
DAG.getTargetConstant(0, dl, MVT::i64));
6397+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
6398+
Op2);
6399+
}
6400+
6401+
// If the result is fixed, we can still use MATCH but we need to wrap the
6402+
// first operand and the mask in scalable vectors before doing so.
6403+
EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
6404+
6405+
// Wrap the operands.
6406+
Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
6407+
Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
6408+
Mask = convertFixedMaskToScalableVector(Mask, DAG);
6409+
6410+
// Carry out the match.
6411+
SDValue Match =
6412+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
6413+
6414+
// Extract and return the result.
6415+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
6416+
DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
6417+
DAG.getVectorIdxConstant(0, dl));
6418+
}
63676419
}
63686420
}
63696421

@@ -27046,6 +27098,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2704627098
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
2704727099
return;
2704827100
}
27101+
case Intrinsic::experimental_vector_match:
2704927102
case Intrinsic::get_active_lane_mask: {
2705027103
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
2705127104
return;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4072,6 +4072,30 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
40724072
}
40734073
}
40744074

4075+
bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
4076+
// Check that (i) the target has SVE2 and SVE is available, (ii) `VT' is a
4077+
// legal type for MATCH, and (iii) the search vector can be broadcast
4078+
// efficiently to a legal type.
4079+
//
4080+
// Currently, we require the length of the search vector to match the minimum
4081+
// number of elements of `VT'. In practice this means we only support the
4082+
// cases (nxv16i8, 16), (v16i8, 16), (nxv8i16, 8), and (v8i16, 8), where the
4083+
// first element of the tuples corresponds to the type of the first argument
4084+
// and the second the length of the search vector.
4085+
//
4086+
// In the future we can support more cases. For example, (nxv16i8, 4) could
4087+
// be efficiently supported by using a DUP.S to broadcast the search
4088+
// elements, and more exotic cases like (nxv16i8, 5) could be supported by a
4089+
// sequence of SEL(DUP).
4090+
if (ST->hasSVE2() && ST->isSVEAvailable() &&
4091+
VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
4092+
(VT->getElementCount().getKnownMinValue() == 8 ||
4093+
VT->getElementCount().getKnownMinValue() == 16) &&
4094+
VT->getElementCount().getKnownMinValue() == SearchSize)
4095+
return true;
4096+
return false;
4097+
}
4098+
40754099
InstructionCost
40764100
AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
40774101
FastMathFlags FMF,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
392392
return ST->hasSVE();
393393
}
394394

395+
bool hasVectorMatch(VectorType *VT, unsigned SearchSize) const;
396+
395397
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
396398
std::optional<FastMathFlags> FMF,
397399
TTI::TargetCostKind CostKind);

0 commit comments

Comments
 (0)