Skip to content

Commit 3e95323

Browse files
committed
[AArch64] Add @llvm.experimental.vector.match
This patch introduces an experimental intrinsic for matching the elements of one vector against the elements of another. For AArch64 targets that support SVE2, it lowers to a MATCH instruction for supported fixed and scalar types.
1 parent 9a222a1 commit 3e95323

File tree

11 files changed

+233
-0
lines changed

11 files changed

+233
-0
lines changed

llvm/docs/LangRef.rst

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19958,6 +19958,59 @@ are undefined.
1995819958
}
1995919959

1996019960

19961+
'``llvm.experimental.vector.match.*``' Intrinsic
19962+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
19963+
19964+
Syntax:
19965+
"""""""
19966+
19967+
This is an overloaded intrinsic. Support for specific vector types is target
19968+
dependent.
19969+
19970+
::
19971+
19972+
declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<n> x <ty>> %op2, <<n> x i1> %mask, i32 <segsize>)
19973+
declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <vscale x <n> x <ty>> %op2, <vscale x <n> x i1> %mask, i32 <segsize>)
19974+
19975+
Overview:
19976+
"""""""""
19977+
19978+
Find elements of the first argument matching any elements of the second.
19979+
19980+
Arguments:
19981+
""""""""""
19982+
19983+
The first argument is the search vector, the second argument is the vector of
19984+
elements we are searching for (i.e. for which we consider a match successful),
19985+
and the third argument is a mask that controls which elements of the first
19986+
argument are active. The fourth argument is an immediate that sets the segment
19987+
size for the search window.
19988+
19989+
Semantics:
19990+
""""""""""
19991+
19992+
The '``llvm.experimental.vector.match``' intrinsic compares each element in the
19993+
first argument against potentially several elements of the second, placing
19994+
``1`` in the corresponding element of the output vector if any comparison is
19995+
successful, and ``0`` otherwise. Inactive elements in the mask are set to ``0``
19996+
in the output. The segment size controls the number of elements of the second
19997+
argument that are compared against.
19998+
19999+
For example, for vectors with 16 elements, if ``segsize = 16`` then each
20000+
element of the first argument is compared against all 16 elements of the second
20001+
argument; but if ``segsize = 4``, then each of the first four elements of the
20002+
first argument is compared against the first four elements of the second
20003+
argument, each of the second four elements of the first argument is compared
20004+
against the second four elements of the second argument, and so forth.
20005+
20006+
Currently, ``segsize`` needs to be an immediate value. The special value of
20007+
``-1`` is allowed to indicate all elements should be searched.
20008+
20009+
Support for specific vector types is target dependent. For AArch64 targets with
20010+
SVE2 support, the intrinsic is valid on ``<16 x i8>`` or ``<8 x i16>`` vectors,
20011+
or the scalable equivalents, with a ``segsize`` equal to the known minimum
20012+
number of elements of the vectors (16 or 8, respectively).
20013+
1996120014
Matrix Intrinsics
1996220015
-----------------
1996320016

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,6 +1744,10 @@ class TargetTransformInfo {
17441744
bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
17451745
Align Alignment) const;
17461746

1747+
/// \returns Returns true if the target supports vector match operations for
1748+
/// the vector type `VT` using a segment size of `SegSize`.
1749+
bool hasVectorMatch(VectorType *VT, unsigned SegSize) const;
1750+
17471751
struct VPLegalization {
17481752
enum VPTransform {
17491753
// keep the predicating parameter
@@ -2182,6 +2186,7 @@ class TargetTransformInfo::Concept {
21822186
virtual bool supportsScalableVectors() const = 0;
21832187
virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
21842188
Align Alignment) const = 0;
2189+
virtual bool hasVectorMatch(VectorType *VT, unsigned SegSize) const = 0;
21852190
virtual VPLegalization
21862191
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
21872192
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -2952,6 +2957,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
29522957
return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
29532958
}
29542959

2960+
bool hasVectorMatch(VectorType *VT, unsigned SegSize) const override {
2961+
return Impl.hasVectorMatch(VT, SegSize);
2962+
}
2963+
29552964
VPLegalization
29562965
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
29572966
return Impl.getVPLegalizationStrategy(PI);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,8 @@ class TargetTransformInfoImplBase {
972972
return false;
973973
}
974974

975+
bool hasVectorMatch(VectorType *VT, unsigned SegSize) const { return false; }
976+
975977
TargetTransformInfo::VPLegalization
976978
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
977979
return TargetTransformInfo::VPLegalization(

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1912,6 +1912,16 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
19121912
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
19131913
[ IntrArgMemOnly ]>;
19141914

1915+
// Experimental match
1916+
def int_experimental_vector_match : DefaultAttrsIntrinsic<
1917+
[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
1918+
[ llvm_anyvector_ty,
1919+
LLVMMatchType<0>,
1920+
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, // Mask
1921+
llvm_i32_ty ], // Segment size
1922+
[ IntrNoMem, IntrNoSync, IntrWillReturn,
1923+
ImmArg<ArgIndex<3>> ]>;
1924+
19151925
// Operators
19161926
let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
19171927
// Integer arithmetic

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,11 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
13541354
return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
13551355
}
13561356

1357+
bool TargetTransformInfo::hasVectorMatch(VectorType *VT,
1358+
unsigned SegSize) const {
1359+
return TTIImpl->hasVectorMatch(VT, SegSize);
1360+
}
1361+
13571362
TargetTransformInfo::Concept::~Concept() = default;
13581363

13591364
TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8137,6 +8137,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81378137
DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
81388138
return;
81398139
}
8140+
case Intrinsic::experimental_vector_match: {
8141+
auto *VT = dyn_cast<VectorType>(I.getOperand(0)->getType());
8142+
auto SegmentSize = cast<ConstantInt>(I.getOperand(3))->getLimitedValue();
8143+
const auto &TTI =
8144+
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
8145+
assert(VT && TTI.hasVectorMatch(VT, SegmentSize) && "Unsupported type!");
8146+
visitTargetIntrinsic(I, Intrinsic);
8147+
return;
8148+
}
81408149
case Intrinsic::vector_reverse:
81418150
visitVectorReverse(I);
81428151
return;

llvm/lib/IR/Verifier.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6108,6 +6108,34 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
61086108
&Call);
61096109
break;
61106110
}
6111+
case Intrinsic::experimental_vector_match: {
6112+
Value *Op1 = Call.getArgOperand(0);
6113+
Value *Op2 = Call.getArgOperand(1);
6114+
Value *Mask = Call.getArgOperand(2);
6115+
Value *SegSize = Call.getArgOperand(3);
6116+
6117+
VectorType *OpTy = dyn_cast<VectorType>(Op1->getType());
6118+
VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());
6119+
Check(OpTy && MaskTy, "experimental.vector.match operands are not vectors.",
6120+
&Call);
6121+
Check(Op2->getType() == OpTy,
6122+
"experimental.vector.match first two operands must have matching "
6123+
"types.",
6124+
&Call);
6125+
Check(isa<ConstantInt>(SegSize),
6126+
"experimental.vector.match segment size needs to be an immediate "
6127+
"integer.",
6128+
&Call);
6129+
6130+
ElementCount EC = OpTy->getElementCount();
6131+
Check(MaskTy->getElementCount() == EC,
6132+
"experimental.vector.match mask must have the same number of "
6133+
"elements as the remaining vector operands.",
6134+
&Call);
6135+
Check(MaskTy->getElementType()->isIntegerTy(1),
6136+
"experimental.vector.match mask element type is not i1.", &Call);
6137+
break;
6138+
}
61116139
case Intrinsic::vector_insert: {
61126140
Value *Vec = Call.getArgOperand(0);
61136141
Value *SubVec = Call.getArgOperand(1);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6255,6 +6255,51 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
62556255
DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
62566256
return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
62576257
}
6258+
case Intrinsic::experimental_vector_match: {
6259+
SDValue ID =
6260+
DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
6261+
6262+
auto Op1 = Op.getOperand(1);
6263+
auto Op2 = Op.getOperand(2);
6264+
auto Mask = Op.getOperand(3);
6265+
auto SegmentSize =
6266+
cast<ConstantSDNode>(Op.getOperand(4))->getLimitedValue();
6267+
6268+
EVT VT = Op.getValueType();
6269+
auto MinNumElts = VT.getVectorMinNumElements();
6270+
6271+
assert(Op1.getValueType() == Op2.getValueType() && "Type mismatch.");
6272+
assert(Op1.getValueSizeInBits().getKnownMinValue() == 128 &&
6273+
"Custom lower only works on 128-bit segments.");
6274+
assert((Op1.getValueType().getVectorElementType() == MVT::i8 ||
6275+
Op1.getValueType().getVectorElementType() == MVT::i16) &&
6276+
"Custom lower only supports 8-bit or 16-bit characters.");
6277+
assert(SegmentSize == MinNumElts && "Custom lower needs segment size to "
6278+
"match minimum number of elements.");
6279+
6280+
if (VT.isScalableVector())
6281+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Mask, Op1, Op2);
6282+
6283+
// We can use the SVE2 match instruction to lower this intrinsic by
6284+
// converting the operands to scalable vectors, doing a match, and then
6285+
// extracting a fixed-width subvector from the scalable vector.
6286+
6287+
EVT OpVT = Op1.getValueType();
6288+
EVT OpContainerVT = getContainerForFixedLengthVector(DAG, OpVT);
6289+
EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
6290+
6291+
auto ScalableOp1 = convertToScalableVector(DAG, OpContainerVT, Op1);
6292+
auto ScalableOp2 = convertToScalableVector(DAG, OpContainerVT, Op2);
6293+
auto ScalableMask = DAG.getNode(ISD::SIGN_EXTEND, dl, OpVT, Mask);
6294+
ScalableMask = convertFixedMaskToScalableVector(ScalableMask, DAG);
6295+
6296+
SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID,
6297+
ScalableMask, ScalableOp1, ScalableOp2);
6298+
6299+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT,
6300+
DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
6301+
DAG.getVectorIdxConstant(0, dl));
6302+
}
62586303
}
62596304
}
62606305

@@ -27304,6 +27349,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2730427349
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
2730527350
return;
2730627351
}
27352+
case Intrinsic::experimental_vector_match:
2730727353
case Intrinsic::get_active_lane_mask: {
2730827354
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
2730927355
return;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4041,6 +4041,18 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
40414041
}
40424042
}
40434043

4044+
bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SegSize) const {
4045+
// Check that the target has SVE2 (and SVE is available), that `VT' is a
4046+
// legal type for MATCH, and that the segment size is 128-bit.
4047+
if (ST->hasSVE2() && ST->isSVEAvailable() &&
4048+
VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
4049+
VT->getElementCount().getKnownMinValue() == SegSize &&
4050+
(VT->getElementCount().getKnownMinValue() == 8 ||
4051+
VT->getElementCount().getKnownMinValue() == 16))
4052+
return true;
4053+
return false;
4054+
}
4055+
40444056
InstructionCost
40454057
AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
40464058
FastMathFlags FMF,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
392392
return ST->hasSVE();
393393
}
394394

395+
bool hasVectorMatch(VectorType *VT, unsigned SegSize) const;
396+
395397
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
396398
std::optional<FastMathFlags> FMF,
397399
TTI::TargetCostKind CostKind);
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
2+
; RUN: llc -mtriple=aarch64 < %s -o - | FileCheck %s
3+
4+
define <vscale x 16 x i1> @match_nxv16i8(<vscale x 16 x i8> %op1, <vscale x 16 x i8> %op2, <vscale x 16 x i1> %mask) #0 {
5+
; CHECK-LABEL: match_nxv16i8:
6+
; CHECK: // %bb.0:
7+
; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
8+
; CHECK-NEXT: ret
9+
%r = tail call <vscale x 16 x i1> @llvm.experimental.vector.match(<vscale x 16 x i8> %op1, <vscale x 16 x i8> %op2, <vscale x 16 x i1> %mask, i32 16)
10+
ret <vscale x 16 x i1> %r
11+
}
12+
13+
define <vscale x 8 x i1> @match_nxv8i16(<vscale x 8 x i16> %op1, <vscale x 8 x i16> %op2, <vscale x 8 x i1> %mask) #0 {
14+
; CHECK-LABEL: match_nxv8i16:
15+
; CHECK: // %bb.0:
16+
; CHECK-NEXT: match p0.h, p0/z, z0.h, z1.h
17+
; CHECK-NEXT: ret
18+
%r = tail call <vscale x 8 x i1> @llvm.experimental.vector.match(<vscale x 8 x i16> %op1, <vscale x 8 x i16> %op2, <vscale x 8 x i1> %mask, i32 8)
19+
ret <vscale x 8 x i1> %r
20+
}
21+
22+
define <16 x i1> @match_v16i8(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask) #0 {
23+
; CHECK-LABEL: match_v16i8:
24+
; CHECK: // %bb.0:
25+
; CHECK-NEXT: shl v2.16b, v2.16b, #7
26+
; CHECK-NEXT: ptrue p0.b, vl16
27+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
28+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
29+
; CHECK-NEXT: cmlt v2.16b, v2.16b, #0
30+
; CHECK-NEXT: cmpne p0.b, p0/z, z2.b, #0
31+
; CHECK-NEXT: match p0.b, p0/z, z0.b, z1.b
32+
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
33+
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
34+
; CHECK-NEXT: ret
35+
%r = tail call <16 x i1> @llvm.experimental.vector.match(<16 x i8> %op1, <16 x i8> %op2, <16 x i1> %mask, i32 16)
36+
ret <16 x i1> %r
37+
}
38+
39+
define <8 x i1> @match_v8i16(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask) #0 {
40+
; CHECK-LABEL: match_v8i16:
41+
; CHECK: // %bb.0:
42+
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
43+
; CHECK-NEXT: ptrue p0.h, vl8
44+
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
45+
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
46+
; CHECK-NEXT: shl v2.8h, v2.8h, #15
47+
; CHECK-NEXT: cmlt v2.8h, v2.8h, #0
48+
; CHECK-NEXT: cmpne p0.h, p0/z, z2.h, #0
49+
; CHECK-NEXT: match p0.h, p0/z, z0.h, z1.h
50+
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
51+
; CHECK-NEXT: xtn v0.8b, v0.8h
52+
; CHECK-NEXT: ret
53+
%r = tail call <8 x i1> @llvm.experimental.vector.match(<8 x i16> %op1, <8 x i16> %op2, <8 x i1> %mask, i32 8)
54+
ret <8 x i1> %r
55+
}
56+
57+
attributes #0 = { "target-features"="+sve2" }

0 commit comments

Comments
 (0)