
Commit 3f9398d

Add support to lower partial search vectors
Also address other review comments.
1 parent e9bd6d4 commit 3f9398d

File tree: 4 files changed (+424, -100 lines)


llvm/lib/IR/Verifier.cpp

Lines changed: 2 additions & 0 deletions
@@ -6173,6 +6173,8 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           &Call);
     Check(MaskTy->getElementType()->isIntegerTy(1),
           "Mask must be a vector of i1's.", &Call);
+    Check(Call.getType() == MaskTy, "Return type must match the mask type.",
+          &Call);
     break;
   }
   case Intrinsic::vector_insert: {
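
For context, the new check ties the intrinsic's return type to its mask type, which matches the intended shape of the operation: one i1 result per mask lane. Below is a minimal scalar sketch of those semantics (illustrative names, byte elements only, not the in-tree reference): for every active lane, report whether the corresponding element of the first operand occurs anywhere in the search vector.

#include <algorithm>
#include <cstdint>
#include <vector>

// Scalar model of the vector.match semantics assumed by the check above:
// the result has exactly one boolean per mask lane.
std::vector<bool> vectorMatchRef(const std::vector<uint8_t> &Op1,
                                 const std::vector<uint8_t> &Op2,
                                 const std::vector<bool> &Mask) {
  std::vector<bool> Result(Mask.size(), false);
  for (size_t I = 0, E = std::min(Op1.size(), Mask.size()); I != E; ++I) {
    if (!Mask[I])
      continue;
    // An active lane matches if its element appears anywhere in Op2.
    for (uint8_t Needle : Op2) {
      if (Op1[I] == Needle) {
        Result[I] = true;
        break;
      }
    }
  }
  return Result;
}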

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 66 additions & 22 deletions
@@ -6379,42 +6379,86 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     assert((Op1VT.getVectorElementType() == MVT::i8 ||
             Op1VT.getVectorElementType() == MVT::i16) &&
            "Expected 8-bit or 16-bit characters.");
-    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
     assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
            "Operand type mismatch.");
-    assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
-           "Invalid operands.");
-
-    // Wrap the search vector in a scalable vector.
-    EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
-    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
-
-    // If the result is scalable, we need to broadbast the search vector across
-    // the SVE register and then carry out the MATCH.
-    if (ResVT.isScalableVector()) {
-      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
-                        DAG.getTargetConstant(0, dl, MVT::i64));
+    assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
+
+    // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
+    // In the future we could support other types (e.g. v8i8).
+    assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
+           "Unsupported first operand type.");
+
+    // Scalable vector type used to wrap operands.
+    // A single container is enough for both operands because ultimately the
+    // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+    EVT OpContainerVT = Op1VT.isScalableVector()
+                            ? Op1VT
+                            : getContainerForFixedLengthVector(DAG, Op1VT);
+
+    // Wrap Op2 in a scalable register, and splat it if necessary.
+    if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
+      // If Op1 and Op2 have the same number of elements, we can trivially
+      // wrap Op2 in an SVE register.
+      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      // If the result is scalable, we need to broadcast Op2 to a full SVE
+      // register.
+      if (ResVT.isScalableVector())
+        Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+                          DAG.getTargetConstant(0, dl, MVT::i64));
+    } else {
+      // If Op1 and Op2 have different numbers of elements, we need to
+      // broadcast Op2. Ideally we would use an AArch64ISD::DUPLANE* node for
+      // this, similarly to the above, but unfortunately it seems we are
+      // missing some patterns for this. So, as an alternative, we splat Op2
+      // through a splat of a scalable vector extract. This idiom, though a
+      // bit more verbose, is supported and gets us the MOV instruction we
+      // want.
+
+      // Some types we need. We'll use an integer type with `Op2BitWidth' bits
+      // to wrap Op2 and simulate the DUPLANE.
+      unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+      MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+      MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
+      EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
+      // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
+      // SVE register before doing the extract and splat.
+      // It is unlikely we'll be widening from types other than v8i8 or v4i16,
+      // so in practice this loop will run for a single iteration.
+      while (Op2VT.getFixedSizeInBits() != 128) {
+        Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
+        Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
+                          DAG.getUNDEF(Op2.getValueType()));
+      }
+      // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
+      Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+      Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
+                        DAG.getBitcast(Op2ScalableVT, Op2),
+                        DAG.getConstant(0, dl, MVT::i64));
+      Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
+      Op2 = DAG.getBitcast(OpContainerVT, Op2);
+    }
+
+    // If the result is scalable, we just need to carry out the MATCH.
+    if (ResVT.isScalableVector())
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
                          Op2);
-    }
 
     // If the result is fixed, we can still use MATCH but we need to wrap the
     // first operand and the mask in scalable vectors before doing so.
-    EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
 
     // Wrap the operands.
     Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
     Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
     Mask = convertFixedMaskToScalableVector(Mask, DAG);
 
-    // Carry out the match.
-    SDValue Match =
-        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
+    // Carry out the match and extract it.
+    SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
+                                Mask.getValueType(), ID, Mask, Op1, Op2);
+    Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
+                        DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
+                        DAG.getVectorIdxConstant(0, dl));
 
-    // Extract and return the result.
-    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
-                       DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
-                       DAG.getVectorIdxConstant(0, dl));
+    // Truncate and return the result.
+    return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
   }
   }
   }
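
At the value level, both paths above amount to tiling the search data across the SVE register: DUPLANE128 repeats a full 128-bit chunk, while the splat-of-extract idiom added here repeats the smaller (e.g. 64-bit) chunk of a partial search vector. A rough model under illustrative assumptions (byte elements, register width passed in explicitly, hypothetical function name):

#include <cstdint>
#include <vector>

// Model of the Op2 broadcast: repeat the search bytes so that every 128-bit
// segment of the register contains all of them. RegBytes stands in for the
// SVE register width in bytes and is assumed to be a multiple of Op2.size().
std::vector<uint8_t> broadcastSearchVector(const std::vector<uint8_t> &Op2,
                                           size_t RegBytes) {
  std::vector<uint8_t> Reg(RegBytes);
  for (size_t I = 0; I < RegBytes; ++I)
    Reg[I] = Op2[I % Op2.size()];
  return Reg;
}

Repeating a 64-bit search vector twice per 128-bit segment is harmless for MATCH, which only asks whether each element of Op1 occurs anywhere within its 128-bit segment of Op2; that is why a plain splat suffices for the partial case.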

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 4 additions & 11 deletions
@@ -4077,21 +4077,14 @@ bool AArch64TTIImpl::hasVectorMatch(VectorType *VT, unsigned SearchSize) const {
   // legal type for MATCH, and (iii) the search vector can be broadcast
   // efficently to a legal type.
   //
-  // Currently, we require the length of the search vector to match the minimum
-  // number of elements of `VT'. In practice this means we only support the
-  // cases (nxv16i8, 16), (v16i8, 16), (nxv8i16, 8), and (v8i16, 8), where the
-  // first element of the tuples corresponds to the type of the first argument
-  // and the second the length of the search vector.
-  //
-  // In the future we can support more cases. For example, (nxv16i8, 4) could
-  // be efficiently supported by using a DUP.S to broadcast the search
-  // elements, and more exotic cases like (nxv16i8, 5) could be supported by a
-  // sequence of SEL(DUP).
+  // Currently, we require the search vector to be 64-bit or 128-bit. In the
+  // future we can support more cases.
   if (ST->hasSVE2() && ST->isSVEAvailable() &&
       VT->getPrimitiveSizeInBits().getKnownMinValue() == 128 &&
       (VT->getElementCount().getKnownMinValue() == 8 ||
        VT->getElementCount().getKnownMinValue() == 16) &&
-      VT->getElementCount().getKnownMinValue() == SearchSize)
+      (VT->getElementCount().getKnownMinValue() == SearchSize ||
+       VT->getElementCount().getKnownMinValue() / 2 == SearchSize))
     return true;
   return false;
 }
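
Restated as a standalone predicate (illustrative only, not the in-tree helper), the condition now also accepts a search vector with half as many elements as the data vector, i.e. 64-bit as well as 128-bit search vectors:

// With SVE2 available, MATCH-based lowering is offered for 128-bit data
// vectors of 8 or 16 elements; the search size may equal the element count
// or be exactly half of it (the partial case added by this commit).
bool hasVectorMatchModel(bool HasSVE2, unsigned VectorMinBits,
                         unsigned MinNumElts, unsigned SearchSize) {
  return HasSVE2 && VectorMinBits == 128 &&
         (MinNumElts == 8 || MinNumElts == 16) &&
         (SearchSize == MinNumElts || SearchSize == MinNumElts / 2);
}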
