Skip to content

[GlobalISel][AArch64] Legalize G_INSERT_VECTOR_ELT for SVE #114310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ inline LegalityPredicate typeIsNot(unsigned TypeIdx, LLT Type) {
LegalityPredicate
typePairInSet(unsigned TypeIdx0, unsigned TypeIdx1,
std::initializer_list<std::pair<LLT, LLT>> TypesInit);
/// True iff the given types for the given tuple of type indexes are one of the
/// specified type tuples.
LegalityPredicate
typeTupleInSet(unsigned TypeIdx0, unsigned TypeIdx1, unsigned TypeIdx2,
Comment on lines +276 to +279
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future could consider being fancy and using variadic templates

std::initializer_list<std::tuple<LLT, LLT, LLT>> TypesInit);
/// True iff the given types for the given pair of type indexes is one of the
/// specified type pairs.
LegalityPredicate typePairAndMemDescInSet(
Expand Down Expand Up @@ -504,6 +509,15 @@ class LegalizeRuleSet {
using namespace LegalityPredicates;
return actionIf(Action, typePairInSet(typeIdx(0), typeIdx(1), Types));
}

LegalizeRuleSet &
actionFor(LegalizeAction Action,
std::initializer_list<std::tuple<LLT, LLT, LLT>> Types) {
using namespace LegalityPredicates;
return actionIf(Action,
typeTupleInSet(typeIdx(0), typeIdx(1), typeIdx(2), Types));
}

/// Use the given action when type indexes 0 and 1 is any type pair in the
/// given list.
/// Action should be an action that requires mutation.
Expand Down Expand Up @@ -615,6 +629,12 @@ class LegalizeRuleSet {
return *this;
return actionFor(LegalizeAction::Legal, Types);
}
LegalizeRuleSet &
legalFor(bool Pred, std::initializer_list<std::tuple<LLT, LLT, LLT>> Types) {
if (!Pred)
return *this;
return actionFor(LegalizeAction::Legal, Types);
}
/// The instruction is legal when type index 0 is any type in the given list
/// and imm index 0 is anything.
LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ LegalityPredicate LegalityPredicates::typePairInSet(
};
}

LegalityPredicate LegalityPredicates::typeTupleInSet(
unsigned TypeIdx0, unsigned TypeIdx1, unsigned TypeIdx2,
std::initializer_list<std::tuple<LLT, LLT, LLT>> TypesInit) {
SmallVector<std::tuple<LLT, LLT, LLT>, 4> Types = TypesInit;
return [=](const LegalityQuery &Query) {
std::tuple<LLT, LLT, LLT> Match = {
Query.Types[TypeIdx0], Query.Types[TypeIdx1], Query.Types[TypeIdx2]};
return llvm::is_contained(Types, Match);
};
}

LegalityPredicate LegalityPredicates::typePairAndMemDescInSet(
unsigned TypeIdx0, unsigned TypeIdx1, unsigned MMOIdx,
std::initializer_list<TypePairAndMemDesc> TypesAndMemDescInit) {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
.legalIf(
typeInSet(0, {v16s8, v8s8, v8s16, v4s16, v4s32, v2s32, v2s64, v2p0}))
.legalFor(HasSVE, {{nxv16s8, s32, s64},
{nxv8s16, s32, s64},
{nxv4s32, s32, s64},
{nxv2s64, s64, s64}})
.moreElementsToNextPow2(0)
.widenVectorEltsToVectorMinSize(0, 64)
.clampNumElements(0, v8s8, v16s8)
Expand Down
51 changes: 43 additions & 8 deletions llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
LLT Ty = MRI.getType(Dst);
if (Ty.isScalableVector())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these needed? I don't believe there should be any shuffles on scalable types.

I think that goes for most of the changes in this file. Are all of these speculative changes that are in reality unnecessary?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got asserts and went over almost all of them. The stack trace did not tell me where the crash was.

return false;
unsigned EltSize = Ty.getScalarSizeInBits();

// Element size for a rev cannot be 64.
Expand Down Expand Up @@ -196,7 +198,10 @@ bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
unsigned WhichResult;
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
Register Dst = MI.getOperand(0).getReg();
unsigned NumElts = MRI.getType(Dst).getNumElements();
LLT DstTy = MRI.getType(Dst);
if (DstTy.isScalableVector())
return false;
unsigned NumElts = DstTy.getNumElements();
if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
return false;
unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
Expand All @@ -217,7 +222,10 @@ bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
unsigned WhichResult;
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
Register Dst = MI.getOperand(0).getReg();
unsigned NumElts = MRI.getType(Dst).getNumElements();
LLT DstTy = MRI.getType(Dst);
if (DstTy.isScalableVector())
return false;
unsigned NumElts = DstTy.getNumElements();
if (!isUZPMask(ShuffleMask, NumElts, WhichResult))
return false;
unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
Expand All @@ -233,7 +241,10 @@ bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
unsigned WhichResult;
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
Register Dst = MI.getOperand(0).getReg();
unsigned NumElts = MRI.getType(Dst).getNumElements();
LLT DstTy = MRI.getType(Dst);
if (DstTy.isScalableVector())
return false;
unsigned NumElts = DstTy.getNumElements();
if (!isZIPMask(ShuffleMask, NumElts, WhichResult))
return false;
unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
Expand Down Expand Up @@ -288,7 +299,10 @@ bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
MachineRegisterInfo &MRI,
ShuffleVectorPseudo &MatchInfo) {
assert(Lane >= 0 && "Expected positive lane?");
int NumElements = MRI.getType(MI.getOperand(1).getReg()).getNumElements();
LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
if (Op1Ty.isScalableVector())
return false;
int NumElements = Op1Ty.getNumElements();
// Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
// lane's definition directly.
auto *BuildVecMI =
Expand Down Expand Up @@ -326,6 +340,8 @@ bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
// Check if an EXT instruction can handle the shuffle mask when the vector
// sources of the shuffle are the same.
bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
if (Ty.isScalableVector())
return false;
unsigned NumElts = Ty.getNumElements();

// Assume that the first shuffle index is not UNDEF. Fail if it is.
Expand Down Expand Up @@ -357,12 +373,17 @@ bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
Register Dst = MI.getOperand(0).getReg();
LLT DstTy = MRI.getType(Dst);
if (DstTy.isScalableVector())
return false;
Register V1 = MI.getOperand(1).getReg();
Register V2 = MI.getOperand(2).getReg();
auto Mask = MI.getOperand(3).getShuffleMask();
uint64_t Imm;
auto ExtInfo = getExtMask(Mask, DstTy.getNumElements());
uint64_t ExtFactor = MRI.getType(V1).getScalarSizeInBits() / 8;
LLT V1Ty = MRI.getType(V1);
if (V1Ty.isScalableVector())
return false;
uint64_t ExtFactor = V1Ty.getScalarSizeInBits() / 8;

if (!ExtInfo) {
if (!getOpcodeDef<GImplicitDef>(V2, MRI) ||
Expand Down Expand Up @@ -423,6 +444,8 @@ void applyNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI,

Register Offset = Insert.getIndexReg();
LLT VecTy = MRI.getType(Insert.getReg(0));
if (VecTy.isScalableVector())
return;
LLT EltTy = MRI.getType(Insert.getElementReg());
LLT IdxTy = MRI.getType(Insert.getIndexReg());

Expand Down Expand Up @@ -473,7 +496,10 @@ bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
Register Dst = MI.getOperand(0).getReg();
int NumElts = MRI.getType(Dst).getNumElements();
LLT DstTy = MRI.getType(Dst);
if (DstTy.isScalableVector())
return false;
int NumElts = DstTy.getNumElements();
auto DstIsLeftAndDstLane = isINSMask(ShuffleMask, NumElts);
if (!DstIsLeftAndDstLane)
return false;
Expand Down Expand Up @@ -522,6 +548,8 @@ bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
if (!Cst)
return false;
Cnt = *Cst;
if (Ty.isScalableVector())
return false;
int64_t ElementBits = Ty.getScalarSizeInBits();
return Cnt >= 1 && Cnt <= ElementBits;
}
Expand Down Expand Up @@ -698,6 +726,8 @@ bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
Register Src1Reg = MI.getOperand(1).getReg();
const LLT SrcTy = MRI.getType(Src1Reg);
const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
if (SrcTy.isScalableVector())
return false;

auto LaneIdx = getSplatIndex(MI);
if (!LaneIdx)
Expand Down Expand Up @@ -774,6 +804,8 @@ bool matchScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI) {
auto &Unmerge = cast<GUnmerge>(MI);
Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
const LLT SrcTy = MRI.getType(Src1Reg);
if (SrcTy.isScalableVector())
return false;
if (SrcTy.getSizeInBits() != 128 && SrcTy.getSizeInBits() != 64)
return false;
return SrcTy.isVector() && !SrcTy.isScalable() &&
Expand Down Expand Up @@ -987,7 +1019,10 @@ bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
if (!DstTy.isVector() || !ST.hasNEON())
return false;
Register LHS = MI.getOperand(2).getReg();
unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
LLT LHSTy = MRI.getType(LHS);
if (LHSTy.isScalableVector())
return false;
unsigned EltSize = LHSTy.getScalarSizeInBits();
if (EltSize == 16 && !ST.hasFullFP16())
return false;
if (EltSize != 16 && EltSize != 32 && EltSize != 64)
Expand Down Expand Up @@ -1183,7 +1218,7 @@ bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);

if (DstTy.isVector()) {
if (DstTy.isFixedVector()) {
// If the source operands were EXTENDED before, then {U/S}MULL can be used
unsigned I1Opc = I1->getOpcode();
unsigned I2Opc = I2->getOpcode();
Expand Down
Loading
Loading