Skip to content

[AArch64][SVE2] Enable dynamic shuffle for fixed length types. #72490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26798,26 +26798,47 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,

// Ignore two operands if no SVE2 or all index numbers couldn't
// be represented.
if (!IsSingleOp && (!Subtarget.hasSVE2() || MinSVESize != MaxSVESize))
if (!IsSingleOp && !Subtarget.hasSVE2())
return SDValue();

EVT VTOp1 = Op.getOperand(0).getValueType();
unsigned BitsPerElt = VTOp1.getVectorElementType().getSizeInBits();
unsigned IndexLen = MinSVESize / BitsPerElt;
unsigned ElementsPerVectorReg = VTOp1.getVectorNumElements();
uint64_t MaxOffset = APInt(BitsPerElt, -1, false).getZExtValue();
EVT MaskEltType = VTOp1.getVectorElementType().changeTypeToInteger();
EVT MaskType = EVT::getVectorVT(*DAG.getContext(), MaskEltType, IndexLen);
bool MinMaxEqual = (MinSVESize == MaxSVESize);
assert(ElementsPerVectorReg <= IndexLen && ShuffleMask.size() <= IndexLen &&
"Incorrectly legalised shuffle operation");

SmallVector<SDValue, 8> TBLMask;
// If MinSVESize is not equal to MaxSVESize then we need to know which
// TBL mask element needs adjustment.
SmallVector<SDValue, 8> AddRuntimeVLMask;

// Bail out for 8-bits element types, because with 2048-bit SVE register
// size 8 bits is only sufficient to index into the first source vector.
if (!IsSingleOp && !MinMaxEqual && BitsPerElt == 8)
return SDValue();

for (int Index : ShuffleMask) {
// Handling poison index value.
if (Index < 0)
Index = 0;
// If we refer to the second operand then we have to add the number of
// elements in a hardware register minus the number of elements in the type.
if ((unsigned)Index >= ElementsPerVectorReg)
Index += IndexLen - ElementsPerVectorReg;
// If the mask refers to elements in the second operand, then we have to
// offset the index by the number of elements in a vector. If this number
// is not known at compile-time, we need to maintain a mask with 'VL' values
// to add at runtime.
if ((unsigned)Index >= ElementsPerVectorReg) {
if (MinMaxEqual) {
Index += IndexLen - ElementsPerVectorReg;
} else {
Index = Index - ElementsPerVectorReg;
AddRuntimeVLMask.push_back(DAG.getConstant(1, DL, MVT::i64));
}
} else if (!MinMaxEqual)
AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
// For 8-bit elements and 1024-bit SVE registers and MaxOffset equals
// to 255, this might point to the last element in the second operand
// of the shufflevector, thus we are rejecting this transform.
Expand All @@ -26830,11 +26851,12 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
// value where it would perform first lane duplication for out of
// index elements. For i8 elements an out-of-range index could be valid
// for 2048-bit vector register size.
for (unsigned i = 0; i < IndexLen - ElementsPerVectorReg; ++i)
for (unsigned i = 0; i < IndexLen - ElementsPerVectorReg; ++i) {
TBLMask.push_back(DAG.getConstant((int)MaxOffset, DL, MVT::i64));
if (!MinMaxEqual)
AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
}

EVT MaskEltType = EVT::getIntegerVT(*DAG.getContext(), BitsPerElt);
EVT MaskType = EVT::getVectorVT(*DAG.getContext(), MaskEltType, IndexLen);
EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskType);
SDValue VecMask =
DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
Expand All @@ -26846,13 +26868,29 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
Op1, SVEMask);
else if (Subtarget.hasSVE2())
else if (Subtarget.hasSVE2()) {
if (!MinMaxEqual) {
unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt;
SDValue VScale = (BitsPerElt == 64)
? DAG.getVScale(DL, MVT::i64, APInt(64, MinNumElts))
: DAG.getVScale(DL, MVT::i32, APInt(32, MinNumElts));
SDValue VecMask =
DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
SDValue MulByMask = DAG.getNode(
ISD::MUL, DL, MaskType,
DAG.getNode(ISD::SPLAT_VECTOR, DL, MaskType, VScale),
DAG.getBuildVector(MaskType, DL,
ArrayRef(AddRuntimeVLMask.data(), IndexLen)));
SDValue UpdatedVecMask =
DAG.getNode(ISD::ADD, DL, MaskType, VecMask, MulByMask);
SVEMask = convertToScalableVector(
DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask);
}
Shuffle =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
Op1, Op2, SVEMask);
else
llvm_unreachable("Cannot lower shuffle without SVE2 TBL");
}
Shuffle = convertFromScalableVector(DAG, VT, Shuffle);
return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
}
Expand Down
Loading