[AArch64][GlobalISel] Basic SVE and fadd #72976

Status: Open · wants to merge 1 commit into base: main

3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
@@ -277,7 +277,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
       }

       const LLT Ty = MRI.getType(VReg);
-      if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
+      if (Ty.isValid() &&
+          TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
         reportGISelFailure(
             MF, TPC, MORE, "gisel-select",
             "VReg's low-level type and register class have different sizes", *MI);
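
The comparison change above is the crux: fixed and scalable sizes are only partially ordered, so a plain `>` (only meaningful when both sides are fixed integers) cannot answer the question for SVE types. `TypeSize::isKnownGT` returns true only when the inequality holds for every possible vscale. A minimal standalone sketch of those semantics (not part of the patch):

#include "llvm/Support/TypeSize.h"
using namespace llvm;

void typeSizeOrderingSketch() {
  TypeSize Scal128 = TypeSize::getScalable(128); // 128 * vscale bits

  // Provable for every vscale >= 1: 256 * vscale > 128 * vscale.
  bool A = TypeSize::isKnownGT(TypeSize::getScalable(256), Scal128); // true

  // Not provable: at vscale == 2, 128 * vscale == 256, so a fixed 256-bit
  // value is not known to be greater than a scalable 128-bit one.
  bool B = TypeSize::isKnownGT(TypeSize::getFixed(256), Scal128); // false
  (void)A; (void)B;
}

So for a scalable vreg in a ZPR-sized class, the check simply stops claiming a size mismatch it cannot prove.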
12 changes: 7 additions & 5 deletions llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
@@ -136,8 +136,8 @@ bool AArch64GenRegisterBankInfo::checkValueMapImpl(unsigned Idx,
                                                    unsigned Size,
                                                    unsigned Offset) {
   unsigned PartialMapBaseIdx = Idx - PartialMappingIdx::PMI_Min;
-  const ValueMapping &Map =
-      AArch64GenRegisterBankInfo::getValueMapping((PartialMappingIdx)FirstInBank, Size)[Offset];
+  const ValueMapping &Map = AArch64GenRegisterBankInfo::getValueMapping(
+      (PartialMappingIdx)FirstInBank, TypeSize::getFixed(Size))[Offset];
   return Map.BreakDown == &PartMappings[PartialMapBaseIdx] &&
          Map.NumBreakDowns == 1;
 }
@@ -167,7 +167,7 @@ bool AArch64GenRegisterBankInfo::checkPartialMappingIdx(
 }

 unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
-                                                             unsigned Size) {
+                                                             TypeSize Size) {
   if (RBIdx == PMI_FirstGPR) {
     if (Size <= 32)
       return 0;
@@ -178,6 +178,8 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
     return -1;
   }
   if (RBIdx == PMI_FirstFPR) {
+    if (Size.isScalable())
+      return 3;
     if (Size <= 16)
       return 0;
     if (Size <= 32)
@@ -197,7 +199,7 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,

 const RegisterBankInfo::ValueMapping *
 AArch64GenRegisterBankInfo::getValueMapping(PartialMappingIdx RBIdx,
-                                            unsigned Size) {
+                                            TypeSize Size) {
   assert(RBIdx != PartialMappingIdx::PMI_None && "No mapping needed for that");
   unsigned BaseIdxOffset = getRegBankBaseIdxOffset(RBIdx, Size);
   if (BaseIdxOffset == -1u)
@@ -221,7 +223,7 @@ const AArch64GenRegisterBankInfo::PartialMappingIdx

 const RegisterBankInfo::ValueMapping *
 AArch64GenRegisterBankInfo::getCopyMapping(unsigned DstBankID,
-                                           unsigned SrcBankID, unsigned Size) {
+                                           unsigned SrcBankID, TypeSize Size) {
   assert(DstBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
   assert(SrcBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
   PartialMappingIdx DstRBIdx = BankIDToCopyMapIdx[DstBankID];
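
For the FPR bank, `getRegBankBaseIdxOffset` buckets sizes into table offsets (16, 32, 64, 128, ... bits); the new early-out folds every scalable size onto offset 3. Reading the surrounding buckets, offset 3 appears to be the 128-bit slot, which lines up with SVE registers having a 128-bit minimum. A standalone restatement of the FPR branch, with the bucket values for the elided lines assumed from context rather than copied from the file:

#include "llvm/Support/TypeSize.h"
using namespace llvm;

// Sketch only: mirrors the shape of the FPR branch after this change.
unsigned fprBaseIdxOffsetSketch(TypeSize Size) {
  if (Size.isScalable())
    return 3;   // every scalable FPR value shares the >=128-bit bucket
  if (Size <= 16)
    return 0;   // FPR16
  if (Size <= 32)
    return 1;   // FPR32
  if (Size <= 64)
    return 2;   // FPR64
  if (Size <= 128)
    return 3;   // FPR128, the bucket scalable sizes reuse
  return -1u;   // larger fixed sizes elided in this sketch
}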
11 changes: 8 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -145,6 +145,11 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                  cl::desc("Maximum of xors"));

+cl::opt<bool> DisableSVEGISel(
+    "aarch64-disable-sve-gisel", cl::Hidden,
+    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+    cl::init(true));
+
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;

@@ -25423,15 +25428,15 @@ bool AArch64TargetLowering::shouldLocalize(
 }

 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
+  if (DisableSVEGISel && Inst.getType()->isScalableTy())
     return true;

   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+    if (DisableSVEGISel && Inst.getOperand(i)->getType()->isScalableTy())
       return true;

   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (AI->getAllocatedType()->isScalableTy())
+    if (DisableSVEGISel && AI->getAllocatedType()->isScalableTy())
       return true;
   }
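
`DisableSVEGISel` defaults to true, so scalable types keep falling back to SelectionDAG unless the flag is flipped on the command line (by the option name above, something like `llc -global-isel -aarch64-disable-sve-gisel=false`). The three call sites repeat one predicate; a hypothetical helper showing the shape of the check (`shouldFallBackForScalable` is not in the patch):

#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
using namespace llvm;

extern cl::opt<bool> DisableSVEGISel; // defined above with cl::init(true)

// Hypothetical refactoring: fall back to SelectionDAG for scalable types
// while the escape hatch is still engaged.
static bool shouldFallBackForScalable(const Type *Ty) {
  return DisableSVEGISel && Ty->isScalableTy();
}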
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -13,7 +13,7 @@
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;

 /// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;

 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;
14 changes: 8 additions & 6 deletions llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -51,6 +51,8 @@

 using namespace llvm;

+extern cl::opt<bool> DisableSVEGISel;
+

Review comment (Contributor), on the extern declaration: Seems unnecessary. I don't think this should be a long lived option

 AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
     : CallLowering(&TLI) {}

@@ -387,8 +389,8 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
       // i1 is a special case because SDAG i1 true is naturally zero extended
       // when widened using ANYEXT. We need to do it explicitly here.
       auto &Flags = CurArgInfo.Flags[0];
-      if (MRI.getType(CurVReg).getSizeInBits() == 1 && !Flags.isSExt() &&
-          !Flags.isZExt()) {
+      if (MRI.getType(CurVReg).getSizeInBits() == TypeSize::getFixed(1) &&
+          !Flags.isSExt() && !Flags.isZExt()) {
         CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg).getReg(0);
       } else if (TLI.getNumRegistersForCallingConv(Ctx, CC, SplitEVTs[i]) ==
                  1) {
@@ -523,10 +525,10 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,

 bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
   auto &F = MF.getFunction();
-  if (F.getReturnType()->isScalableTy() ||
-      llvm::any_of(F.args(), [](const Argument &A) {
-        return A.getType()->isScalableTy();
-      }))
+  if (DisableSVEGISel && (F.getReturnType()->isScalableTy() ||
+                          llvm::any_of(F.args(), [](const Argument &A) {
+                            return A.getType()->isScalableTy();
+                          })))
     return true;
   const auto &ST = MF.getSubtarget<AArch64Subtarget>();
   if (!ST.hasNEON() || !ST.hasFPARMv8()) {
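
The `== TypeSize::getFixed(1)` spelling in `lowerReturn` is the scalable-safe version of `== 1`: comparing two `TypeSize` values compares quantity plus scalability and simply yields false across the fixed/scalable divide, whereas comparing against a bare integer forces a conversion that is only valid for fixed sizes. A minimal sketch (the header path differs across LLVM versions):

#include "llvm/CodeGenTypes/LowLevelType.h" // LLT; older trees use llvm/CodeGen/
#include "llvm/Support/TypeSize.h"
using namespace llvm;

void typeSizeEqualitySketch() {
  LLT S1 = LLT::scalar(1);                   // 1-bit scalar
  LLT NxV4S32 = LLT::scalable_vector(4, 32); // <vscale x 4 x s32>

  bool A = S1.getSizeInBits() == TypeSize::getFixed(1);      // true
  bool B = NxV4S32.getSizeInBits() == TypeSize::getFixed(1); // false, and
  // crucially no assertion fires while evaluating it.
  (void)A; (void)B;
}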
27 changes: 20 additions & 7 deletions llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -595,11 +595,12 @@ getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB,
 /// Given a register bank, and size in bits, return the smallest register class
 /// that can represent that combination.
 static const TargetRegisterClass *
-getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
+getMinClassForRegBank(const RegisterBank &RB, TypeSize SizeInBits,
                       bool GetAllRegSet = false) {
   unsigned RegBankID = RB.getID();

   if (RegBankID == AArch64::GPRRegBankID) {
+    assert(!SizeInBits.isScalable() && "Unexpected scalable register size");
     if (SizeInBits <= 32)
       return GetAllRegSet ? &AArch64::GPR32allRegClass
                           : &AArch64::GPR32RegClass;
@@ -611,6 +612,12 @@ getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
   }

   if (RegBankID == AArch64::FPRRegBankID) {
+    if (SizeInBits.isScalable()) {
+      assert(SizeInBits == TypeSize::getScalable(128) &&
+             "Unexpected scalable register size");
+      return &AArch64::ZPRRegClass;
+    }
+
     switch (SizeInBits) {
     default:
       return nullptr;
@@ -937,8 +944,8 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   Register SrcReg = I.getOperand(1).getReg();
   const RegisterBank &DstRegBank = *RBI.getRegBank(DstReg, MRI, TRI);
   const RegisterBank &SrcRegBank = *RBI.getRegBank(SrcReg, MRI, TRI);
-  unsigned DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
-  unsigned SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
+  TypeSize DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
+  TypeSize SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);

   // Special casing for cross-bank copies of s1s. We can technically represent
   // a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -948,8 +955,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   // then we can pull it into the helpers that get the appropriate class for a
   // register bank. Or make a new helper that carries along some constraint
   // information.
-  if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
-    SrcSize = DstSize = 32;
+  if (SrcRegBank != DstRegBank &&
+      (DstSize == TypeSize::getFixed(1) && SrcSize == TypeSize::getFixed(1)))
+    SrcSize = DstSize = TypeSize::getFixed(32);

   return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
           getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1014,10 +1022,15 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
     return false;
   }

-  unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
-  unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
+  TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
+  TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
   unsigned SubReg;

+  if (SrcSize.isScalable()) {
+    assert(DstSize.isScalable() && "Unhandled scalable copy");
+    return true;
+  }
+
   // If the source bank doesn't support a subregister copy small enough,
   // then we first need to copy to the destination bank.
   if (getMinSizeForRegBank(SrcRegBank) > DstSize) {
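
Taken together, the selector changes route every scalable FPR-bank value to the ZPR class and accept scalable-to-scalable copies without running the fixed-size subregister logic. An illustrative call, as it would look inside this file (the scalable 128-bit size is what all three SVE types legalized by this patch report as their minimum):

// Sketch: what the modified helper hands back for a scalable value.
const TargetRegisterClass *RC = getMinClassForRegBank(
    AArch64::FPRRegBank, TypeSize::getScalable(128), /*GetAllRegSet=*/true);
// RC == &AArch64::ZPRRegClass; a scalable size on the GPR bank would
// instead trip the new assert.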
6 changes: 5 additions & 1 deletion llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -59,6 +59,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v4s32 = LLT::fixed_vector(4, 32);
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
+  const LLT nxv8s16 = LLT::scalable_vector(8, 16);
+  const LLT nxv4s32 = LLT::scalable_vector(4, 32);
+  const LLT nxv2s64 = LLT::scalable_vector(2, 64);

   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
       v16s8, v8s16, v4s32,
@@ -238,7 +241,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                                G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
                                G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
                                G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})

Review comment, on the builder above: Maybe move G_FADD into a separate builder and slowly move/legalize/merge the two builders back into one? G_INTRINSIC_TRUNC is legal for SVE?!?

-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
+      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64, nxv8s16, nxv4s32,
+                 nxv2s64})

Review comment, on the added types: HasSVE?

       .legalIf([=](const LegalityQuery &Query) {
         const auto &Ty = Query.Types[0];
         return (Ty == v8s16 || Ty == v4s16) && HasFP16;
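
The `nxv...` constants are the scalable counterparts of the 128-bit NEON vectors: `LLT::scalable_vector(N, S)` builds `<vscale x N x sS>`, and all three report a scalable 128-bit minimum size, i.e. one SVE Z register. With them in `legalFor`, a `G_FADD` on `nxv4s32` is now reported legal unconditionally, which is what prompts the `HasSVE?` review question above. A standalone sketch of what these LLTs report:

#include "llvm/CodeGenTypes/LowLevelType.h" // LLT; older trees use llvm/CodeGen/
#include <cassert>
using namespace llvm;

void scalableLLTSketch() {
  const LLT NxV8S16 = LLT::scalable_vector(8, 16); // <vscale x 8 x s16>
  const LLT NxV4S32 = LLT::scalable_vector(4, 32); // <vscale x 4 x s32>
  const LLT NxV2S64 = LLT::scalable_vector(2, 64); // <vscale x 2 x s64>

  assert(NxV4S32.isVector() && NxV4S32.isScalable());
  // All three occupy one Z register at minimum: a scalable 128 bits.
  assert(NxV8S16.getSizeInBits() == TypeSize::getScalable(128));
  assert(NxV4S32.getSizeInBits() == TypeSize::getScalable(128));
  assert(NxV2S64.getSizeInBits() == TypeSize::getScalable(128));
}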
64 changes: 33 additions & 31 deletions llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -162,17 +162,18 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
     unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min;              \
     (void)PartialMapDstIdx;                                                   \
     (void)PartialMapSrcIdx;                                                   \
-    const ValueMapping *Map = getCopyMapping(                                 \
-        AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size);  \
+    const ValueMapping *Map = getCopyMapping(AArch64::RBNameDst##RegBankID,   \
+                                             AArch64::RBNameSrc##RegBankID,   \
+                                             TypeSize::getFixed(Size));       \
     (void)Map;                                                                \
     assert(Map[0].BreakDown ==                                                \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] && \
-           Map[0].NumBreakDowns == 1 && #RBNameDst #Size                      \
-           " Dst is incorrectly initialized");                                \
+           Map[0].NumBreakDowns == 1 &&                                       \
+           #RBNameDst #Size " Dst is incorrectly initialized");               \
     assert(Map[1].BreakDown ==                                                \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] && \
-           Map[1].NumBreakDowns == 1 && #RBNameSrc #Size                      \
-           " Src is incorrectly initialized");                                \
+           Map[1].NumBreakDowns == 1 &&                                       \
+           #RBNameSrc #Size " Src is incorrectly initialized");               \
                                                                               \
   } while (false)
@@ -256,6 +257,9 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPR_3bRegClassID:
+  case AArch64::ZPR_4bRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -300,8 +304,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
   case TargetOpcode::G_OR: {
     // 32 and 64-bit or can be mapped on either FPR or
     // GPR for the same cost.
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::getFixed(32) && Size != TypeSize::getFixed(64))
       break;

     // If the instruction has any implicit-defs or uses,
@@ -321,8 +325,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_BITCAST: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::getFixed(32) && Size != TypeSize::getFixed(64))
       break;

     // If the instruction has any implicit-defs or uses,
@@ -341,16 +345,12 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
                            /*NumOperands*/ 2);
     const InstructionMapping &GPRToFPRMapping = getInstructionMapping(
         /*ID*/ 3,
-        /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::getFixed(Size)),
+        /*Cost*/ copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::FPRRegBankID, AArch64::GPRRegBankID, Size),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRToGPRMapping = getInstructionMapping(
         /*ID*/ 3,
-        /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::getFixed(Size)),
+        /*Cost*/ copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::GPRRegBankID, AArch64::FPRRegBankID, Size),
         /*NumOperands*/ 2);

@@ -361,8 +361,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_LOAD: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::getFixed(64))
       break;

     // If the instruction has any implicit-defs or uses,
@@ -373,15 +373,17 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     InstructionMappings AltMappings;
     const InstructionMapping &GPRMapping = getInstructionMapping(
         /*ID*/ 1, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstGPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::getFixed(64))}),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRMapping = getInstructionMapping(
         /*ID*/ 2, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstFPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::getFixed(64))}),
         /*NumOperands*/ 2);

     AltMappings.push_back(&GPRMapping);
@@ -459,7 +461,7 @@ AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
          "This code is for instructions with 3 or less operands");

   LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-  unsigned Size = Ty.getSizeInBits();
+  TypeSize Size = Ty.getSizeInBits();
   bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);

   PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
@@ -719,9 +721,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     // If both RB are null that means both registers are generic.
     // We shouldn't be here.
     assert(DstRB && SrcRB && "Both RegBank were nullptr");
-    unsigned Size = getSizeInBits(DstReg, MRI, TRI);
+    TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
     return getInstructionMapping(
-        DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::getFixed(Size)),
+        DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
         getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
         // We only care about the mapping of the destination.
         /*NumOperands*/ 1);
@@ -732,15 +734,15 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_BITCAST: {
     LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
-    unsigned Size = DstTy.getSizeInBits();
+    TypeSize Size = DstTy.getSizeInBits();
     bool DstIsGPR = !DstTy.isVector() && DstTy.getSizeInBits() <= 64;
     bool SrcIsGPR = !SrcTy.isVector() && SrcTy.getSizeInBits() <= 64;
     const RegisterBank &DstRB =
         DstIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
     const RegisterBank &SrcRB =
         SrcIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
     return getInstructionMapping(
-        DefaultMappingID, copyCost(DstRB, SrcRB, TypeSize::getFixed(Size)),
+        DefaultMappingID, copyCost(DstRB, SrcRB, Size),
         getCopyMapping(DstRB.getID(), SrcRB.getID(), Size),
         // We only care about the mapping of the destination for COPY.
         /*NumOperands*/ Opc == TargetOpcode::G_BITCAST ? 2 : 1);
@@ -752,7 +754,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   unsigned NumOperands = MI.getNumOperands();

   // Track the size and bank of each register. We don't do partial mappings.
-  SmallVector<unsigned, 4> OpSize(NumOperands);
+  SmallVector<TypeSize, 4> OpSize(NumOperands, TypeSize::getFixed(0));
   SmallVector<PartialMappingIdx, 4> OpRegBankIdx(NumOperands);
   for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
     auto &MO = MI.getOperand(Idx);
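
A small but telling detail in the hunk above: `TypeSize` has no default constructor (a size must be declared fixed or scalable up front), so the resized `SmallVector<TypeSize>` needs an explicit fill value, hence `TypeSize::getFixed(0)`. Sketch:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/TypeSize.h"
using namespace llvm;

void opSizeVectorSketch(unsigned NumOperands) {
  // SmallVector<TypeSize, 4> Bad(NumOperands);  // would not compile
  SmallVector<TypeSize, 4> OpSize(NumOperands, TypeSize::getFixed(0));
  for (TypeSize &TS : OpSize)
    TS = TypeSize::getScalable(128); // entries may later hold scalable sizes
}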
@@ -833,7 +835,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       Cost = copyCost(
           *AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[0]].RegBank,
           *AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[1]].RegBank,
-          TypeSize::getFixed(OpSize[0]));
+          OpSize[0]);
       break;
     case TargetOpcode::G_LOAD: {
       // Loading in vector unit is slightly more expensive.