[AArch64] Add @llvm.experimental.vector.match #101974

Merged: 8 commits, Nov 14, 2024
38 changes: 38 additions & 0 deletions llvm/docs/LangRef.rst
@@ -20089,6 +20089,44 @@ are undefined.
}


'``llvm.experimental.vector.match.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

This is an overloaded intrinsic.

::

declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<m> x <ty>> %op2, <<n> x i1> %mask)
declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <<m> x <ty>> %op2, <vscale x <n> x i1> %mask)

Overview:
"""""""""

Find active elements of the first argument that match any element of the second.

Arguments:
""""""""""

The first argument is the search vector, the second argument is the vector of
elements we are searching for (i.e. for which we consider a match successful),
and the third argument is a mask that controls which elements of the first
argument are active. The first two arguments must be vectors of matching
integer element types. The first and third arguments and the result type must
have matching element counts (fixed or scalable). The second argument must be a
fixed-length vector, but its length may differ from that of the remaining
arguments.

Semantics:
""""""""""

The '``llvm.experimental.vector.match``' intrinsic compares each active element
in the first argument against the elements of the second argument, placing
``1`` in the corresponding element of the output vector if any equality
comparison is successful, and ``0`` otherwise. Output elements corresponding to
inactive mask elements are set to ``0``.
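
For example, here is a sketch of a fixed-length use with a two-element search
vector and the last lane masked off (the ``.v4i8.v2i8`` suffix follows the two
overloaded vector types; the concrete values are illustrative)::

      %r = call <4 x i1> @llvm.experimental.vector.match.v4i8.v2i8(
                   <4 x i8> <i8 1, i8 2, i8 3, i8 4>,
                   <2 x i8> <i8 2, i8 5>,
                   <4 x i1> <i1 1, i1 1, i1 1, i1 0>)

      ; %r is <i1 0, i1 1, i1 0, i1 0>: only element 1 equals a search
      ; element, and element 3 is inactive, so its result is 0.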

Matrix Intrinsics
-----------------

7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -483,6 +483,13 @@ class TargetLoweringBase {
bool ZeroIsPoison,
const ConstantRange *VScaleRange) const;

/// Return true if the @llvm.experimental.vector.match intrinsic should be
/// expanded for vector type `VT' and search size `SearchSize' using generic
/// code in SelectionDAGBuilder.
virtual bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const {
return true;
}

// Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
// vecreduce(op(x, y)) for the reduction opcode RedOpc.
virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
@@ -1920,6 +1920,14 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
[ IntrArgMemOnly ]>;

// Experimental match
def int_experimental_vector_match : DefaultAttrsIntrinsic<
[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
[ llvm_anyvector_ty,
llvm_anyvector_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], // Mask
[ IntrNoMem, IntrNoSync, IntrWillReturn ]>;

// Operators
let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
// Integer arithmetic
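Since both vector parameters are ``llvm_anyvector_ty``, each concrete use
instantiates a separately mangled declaration. A sketch of one such
instantiation (the suffixes come from the two overloaded operand types)::

    declare <vscale x 16 x i1> @llvm.experimental.vector.match.nxv16i8.v16i8(
        <vscale x 16 x i8>, <16 x i8>, <vscale x 16 x i1>)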
30 changes: 30 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8176,6 +8176,36 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
return;
}
case Intrinsic::experimental_vector_match: {
SDValue Op1 = getValue(I.getOperand(0));
SDValue Op2 = getValue(I.getOperand(1));
SDValue Mask = getValue(I.getOperand(2));
EVT Op1VT = Op1.getValueType();
EVT Op2VT = Op2.getValueType();
EVT ResVT = Mask.getValueType();
unsigned SearchSize = Op2VT.getVectorNumElements();

// If the target has native support for this vector match operation, lower
// the intrinsic untouched; otherwise, expand it below.
if (!TLI.shouldExpandVectorMatch(Op1VT, SearchSize)) {
visitTargetIntrinsic(I, Intrinsic);
return;
}

SDValue Ret = DAG.getConstant(0, sdl, ResVT);

for (unsigned i = 0; i < SearchSize; ++i) {
SDValue Op2Elem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
Op2VT.getVectorElementType(), Op2,
DAG.getVectorIdxConstant(i, sdl));
SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, sdl, Op1VT, Op2Elem);
SDValue Cmp = DAG.getSetCC(sdl, ResVT, Op1, Splat, ISD::SETEQ);
Ret = DAG.getNode(ISD::OR, sdl, ResVT, Ret, Cmp);
}

setValue(&I, DAG.getNode(ISD::AND, sdl, ResVT, Ret, Mask));
return;
}
case Intrinsic::vector_reverse:
visitVectorReverse(I);
return;
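At the IR level, the generic expansion above is equivalent to the following
sketch for a ``<4 x i8>`` haystack and two search elements: splat each search
element, compare for equality, OR the comparisons together, then apply the
mask (names illustrative)::

    %e0  = extractelement <2 x i8> %op2, i64 0
    %v0  = insertelement <4 x i8> poison, i8 %e0, i64 0
    %s0  = shufflevector <4 x i8> %v0, <4 x i8> poison, <4 x i32> zeroinitializer
    %c0  = icmp eq <4 x i8> %op1, %s0
    %e1  = extractelement <2 x i8> %op2, i64 1
    %v1  = insertelement <4 x i8> poison, i8 %e1, i64 0
    %s1  = shufflevector <4 x i8> %v1, <4 x i8> poison, <4 x i32> zeroinitializer
    %c1  = icmp eq <4 x i8> %op1, %s1
    %any = or <4 x i1> %c0, %c1
    %r   = and <4 x i1> %any, %mask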
25 changes: 25 additions & 0 deletions llvm/lib/IR/Verifier.cpp
@@ -6150,6 +6150,31 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
&Call);
break;
}
case Intrinsic::experimental_vector_match: {
Value *Op1 = Call.getArgOperand(0);
Value *Op2 = Call.getArgOperand(1);
Value *Mask = Call.getArgOperand(2);

VectorType *Op1Ty = dyn_cast<VectorType>(Op1->getType());
VectorType *Op2Ty = dyn_cast<VectorType>(Op2->getType());
VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());

Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
Check(isa<FixedVectorType>(Op2Ty),
"Second operand must be a fixed length vector.", &Call);
Check(Op1Ty->getElementType()->isIntegerTy(),
"First operand must be a vector of integers.", &Call);
Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
"First two operands must have the same element type.", &Call);
Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
"First operand and mask must have the same number of elements.",
&Call);
Check(MaskTy->getElementType()->isIntegerTy(1),
"Mask must be a vector of i1's.", &Call);
Check(Call.getType() == MaskTy, "Return type must match the mask type.",
&Call);
break;
}
case Intrinsic::vector_insert: {
Value *Vec = Call.getArgOperand(0);
Value *SubVec = Call.getArgOperand(1);
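For example, the new verifier case rejects a scalable second operand; a sketch
of IR that would fail these checks::

    ; Invalid: "Second operand must be a fixed length vector."
    %r = call <vscale x 16 x i1>
        @llvm.experimental.vector.match.nxv16i8.nxv16i8(
            <vscale x 16 x i8> %op1, <vscale x 16 x i8> %op2,
            <vscale x 16 x i1> %mask)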
83 changes: 83 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2057,6 +2057,19 @@ bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
VT != MVT::v4i1 && VT != MVT::v2i1;
}

bool AArch64TargetLowering::shouldExpandVectorMatch(EVT VT,
unsigned SearchSize) const {
// MATCH is SVE2 and only available in non-streaming mode.
if (!Subtarget->hasSVE2() || !Subtarget->isSVEAvailable())
return true;
// Furthermore, we can only use it for 8-bit or 16-bit elements.
if (VT == MVT::nxv8i16 || VT == MVT::v8i16)
return SearchSize != 8;
if (VT == MVT::nxv16i8 || VT == MVT::v16i8 || VT == MVT::v8i8)
return SearchSize != 8 && SearchSize != 16;
return true;
}

void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");

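In other words, shouldExpandVectorMatch returns false (keeping the intrinsic
intact) only for searches MATCH supports natively: 8 or 16 search elements
against nxv16i8/v16i8/v8i8 data, or 8 search elements against nxv8i16/v8i16
data. A sketch of a call that takes the native path on an SVE2, non-streaming
target (names illustrative)::

    %r = call <vscale x 16 x i1>
        @llvm.experimental.vector.match.nxv16i8.v16i8(
            <vscale x 16 x i8> %haystack, <16 x i8> %needles,
            <vscale x 16 x i1> %mask)
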
@@ -5778,6 +5791,72 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
}

SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
SDLoc dl(Op);
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);

auto Op1 = Op.getOperand(1);
auto Op2 = Op.getOperand(2);
auto Mask = Op.getOperand(3);

EVT Op1VT = Op1.getValueType();
EVT Op2VT = Op2.getValueType();
EVT ResVT = Op.getValueType();

assert((Op1VT.getVectorElementType() == MVT::i8 ||
Op1VT.getVectorElementType() == MVT::i16) &&
"Expected 8-bit or 16-bit characters.");

// Scalable vector type used to wrap operands.
// A single container is enough for both operands because ultimately the
// operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
EVT OpContainerVT = Op1VT.isScalableVector()
? Op1VT
: getContainerForFixedLengthVector(DAG, Op1VT);

if (Op2VT.is128BitVector()) {
// If Op2 is a full 128-bit vector, wrap it trivially in a scalable vector.
Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
// Further, if the result is scalable, broadcast Op2 to a full SVE register.
if (ResVT.isScalableVector())
Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
DAG.getTargetConstant(0, dl, MVT::i64));
} else {
// If Op2 is not a full 128-bit vector, we always need to broadcast it.
unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
EVT Op2PromotedVT = getPackedSVEVectorVT(Op2IntVT);
Op2 = DAG.getBitcast(MVT::getVectorVT(Op2IntVT, 1), Op2);
Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT, Op2,
DAG.getConstant(0, dl, MVT::i64));
Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
Op2 = DAG.getBitcast(OpContainerVT, Op2);
}

// If the result is scalable, we just need to carry out the MATCH.
if (ResVT.isScalableVector())
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1, Op2);

// If the result is fixed, we can still use MATCH but we need to wrap the
// first operand and the mask in scalable vectors before doing so.

// Wrap the operands.
Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, Op1VT, Mask);
Mask = convertFixedMaskToScalableVector(Mask, DAG);

// Carry out the match.
SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Mask.getValueType(),
ID, Mask, Op1, Op2);

// Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
// (v16i8/v8i8).
Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
Match = convertFromScalableVector(DAG, Op1VT, Match);
return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
}
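
A sketch of the fixed-length case handled by the tail of this function: given
IR like the call below, the data and mask are wrapped into scalable
containers, MATCH produces a predicate, and the result is converted back to a
fixed vector (names illustrative)::

    %r = call <16 x i1> @llvm.experimental.vector.match.v16i8.v16i8(
             <16 x i8> %haystack, <16 x i8> %needles, <16 x i1> %mask)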

SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
SelectionDAG &DAG) const {
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -6381,6 +6460,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
}
case Intrinsic::experimental_vector_match: {
return LowerVectorMatch(Op, DAG);
}
}
}

@@ -27100,6 +27182,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
return;
}
case Intrinsic::experimental_vector_match:
case Intrinsic::get_active_lane_mask: {
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
return;
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -985,6 +985,8 @@ class AArch64TargetLowering : public TargetLowering {

bool shouldExpandCttzElements(EVT VT) const override;

bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;

/// If a change in streaming mode is required on entry to/return from a
/// function call it emits and returns the corresponding SMSTART or SMSTOP
/// node. \p Condition should be one of the enum values from