
[AArch64][GISel] Translate legal SVE formal arguments and select COPY for SVE #95236

Merged on Jun 18, 2024 (6 commits)
26 changes: 16 additions & 10 deletions llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
@@ -137,7 +137,9 @@ bool AArch64GenRegisterBankInfo::checkValueMapImpl(unsigned Idx,
unsigned Offset) {
unsigned PartialMapBaseIdx = Idx - PartialMappingIdx::PMI_Min;
const ValueMapping &Map =
- AArch64GenRegisterBankInfo::getValueMapping((PartialMappingIdx)FirstInBank, Size)[Offset];
+ AArch64GenRegisterBankInfo::getValueMapping(
+ (PartialMappingIdx)FirstInBank,
+ TypeSize::getFixed(Size))[Offset];
return Map.BreakDown == &PartMappings[PartialMapBaseIdx] &&
Map.NumBreakDowns == 1;
}
@@ -167,7 +169,7 @@ bool AArch64GenRegisterBankInfo::checkPartialMappingIdx(
}

unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
- unsigned Size) {
+ TypeSize Size) {
if (RBIdx == PMI_FirstGPR) {
if (Size <= 32)
return 0;
@@ -178,17 +180,20 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
return -1;
}
if (RBIdx == PMI_FirstFPR) {
- if (Size <= 16)
+ const unsigned MinSize = Size.getKnownMinValue();
+ assert(!Size.isScalable() || MinSize >= 128
+ && "Scalable vector types should have size of at least 128 bits");
+ if (MinSize <= 16)
return 0;
- if (Size <= 32)
+ if (MinSize <= 32)
return 1;
- if (Size <= 64)
+ if (MinSize <= 64)
return 2;
- if (Size <= 128)
+ if (MinSize <= 128)
return 3;
- if (Size <= 256)
+ if (MinSize <= 256)
return 4;
- if (Size <= 512)
+ if (MinSize <= 512)
return 5;
return -1;
}
@@ -197,7 +202,7 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,

const RegisterBankInfo::ValueMapping *
AArch64GenRegisterBankInfo::getValueMapping(PartialMappingIdx RBIdx,
- unsigned Size) {
+ const TypeSize Size) {
assert(RBIdx != PartialMappingIdx::PMI_None && "No mapping needed for that");
unsigned BaseIdxOffset = getRegBankBaseIdxOffset(RBIdx, Size);
if (BaseIdxOffset == -1u)
@@ -221,7 +226,8 @@ const AArch64GenRegisterBankInfo::PartialMappingIdx

const RegisterBankInfo::ValueMapping *
AArch64GenRegisterBankInfo::getCopyMapping(unsigned DstBankID,
- unsigned SrcBankID, unsigned Size) {
+ unsigned SrcBankID,
+ const TypeSize Size) {
assert(DstBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
assert(SrcBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
PartialMappingIdx DstRBIdx = BankIDToCopyMapIdx[DstBankID];
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -149,7 +149,7 @@ static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
// scalable vector types for all instruction, even if SVE is not yet supported
// with some instructions.
// See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
- static cl::opt<bool> EnableSVEGISel(
+ cl::opt<bool> EnableSVEGISel(
"aarch64-enable-gisel-sve", cl::Hidden,
cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
cl::init(false));
10 changes: 6 additions & 4 deletions llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -53,6 +53,8 @@
using namespace llvm;
using namespace AArch64GISelUtils;

+ extern cl::opt<bool> EnableSVEGISel;

Review comment (Contributor):

This shouldn't need to spread. I don't like having fallBackToDAGISel at all, but I see we have 2 of them. Why not just put them both in the same place?

Reply from @Him188 (Member, Author), Jun 12, 2024:

AArch64CallLowering::fallBackToDAGISel works on MachineFunction as an early path before visiting each instruction.
AArch64TargetLowering::fallBackToDAGISel works on each Instruction.

FastISel also uses AArch64TargetLowering::fallBackToDAGISel, making it more complex to merge the logic.

This patch does not completely support SVE formal arguments, especially for predicate registers <vscale x 16 x i1>.
The option is needed until we support them.
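
A minimal, self-contained sketch of the function-level gate described in this reply, assuming nothing beyond the condition visible in the diff. `TypeInfo`, `FunctionSignature`, and `fallBackForScalableSignature` are hypothetical stand-ins for illustration, not LLVM types; the real check is in `AArch64CallLowering::fallBackToDAGISel` and reads the `aarch64-enable-gisel-sve` option.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical stand-in for an IR type; it only records scalability.
struct TypeInfo {
  bool IsScalable;
};

// Hypothetical stand-in for a function's return type and formal arguments.
struct FunctionSignature {
  TypeInfo ReturnType;
  std::vector<TypeInfo> Args;
};

// Mirrors the shape of the patched condition: fall back to SelectionDAG only
// when the debug option is off AND the signature mentions a scalable type.
// The real function goes on to check other things, which are omitted here.
bool fallBackForScalableSignature(const FunctionSignature &F,
                                  bool EnableSVEGISel) {
  const bool UsesScalable =
      F.ReturnType.IsScalable ||
      std::any_of(F.Args.begin(), F.Args.end(),
                  [](const TypeInfo &T) { return T.IsScalable; });
  return !EnableSVEGISel && UsesScalable;
}
```

With the option off, a signature that takes a scalable argument makes the sketch return `true` (fall back); with the option on, the same signature is allowed to proceed through GlobalISel.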


AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
: CallLowering(&TLI) {}

@@ -525,10 +527,10 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,

bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
auto &F = MF.getFunction();
- if (F.getReturnType()->isScalableTy() ||
- llvm::any_of(F.args(), [](const Argument &A) {
- return A.getType()->isScalableTy();
- }))
+ if (!EnableSVEGISel && (F.getReturnType()->isScalableTy() ||

Review comment:

Are scalable return types enabled now?

Reply from @Him188 (Member, Author), Jun 13, 2024:

Returning scalable types is not supported yet. I'm following the same idea as #92130 (comment): enable SVE only behind a debug option.

+ llvm::any_of(F.args(), [](const Argument &A) {
+ return A.getType()->isScalableTy();
+ })))
return true;
const auto &ST = MF.getSubtarget<AArch64Subtarget>();
if (!ST.hasNEON() || !ST.hasFPARMv8()) {
19 changes: 13 additions & 6 deletions llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -597,8 +597,14 @@ getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB,
/// Given a register bank, and size in bits, return the smallest register class
/// that can represent that combination.
static const TargetRegisterClass *
- getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
+ getMinClassForRegBank(const RegisterBank &RB, TypeSize SizeInBits,

Review comment (Contributor):

The variable name "SizeInBits" is not appropriate for the type "TypeSize". Please use something else.

Reply from @Him188 (Member, Author), Jun 13, 2024:

I think it's fine, as the value is obtained from `TypeSize RegisterBankInfo::getSizeInBits()`.

Both `TypeSize LLT::getSizeInBytes()` and `TypeSize LLT::getSizeInBits()` exist, so `TypeSize` itself does not carry a unit. It is therefore better to state the unit in the function parameter name.
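
To illustrate the point about units, here is a simplified sketch of a `TypeSize`-like value. `MiniTypeSize` and `minClassName` are hypothetical names invented for this example; LLVM's real `TypeSize` has a richer interface, and the fixed-size class names below are placeholders. The value stores only a known minimum and a scalable flag, so the unit has to come from the parameter name.

```cpp
#include <cstdint>
#include <string>

// Simplified, hypothetical model of a quantity that is either fixed or a
// multiple of an unknown runtime factor (vscale). It carries no unit.
class MiniTypeSize {
  uint64_t KnownMin;
  bool Scalable;
  MiniTypeSize(uint64_t Min, bool IsScalable)
      : KnownMin(Min), Scalable(IsScalable) {}

public:
  static MiniTypeSize getFixed(uint64_t V) { return MiniTypeSize(V, false); }
  static MiniTypeSize getScalable(uint64_t Min) {
    return MiniTypeSize(Min, true);
  }
  bool isScalable() const { return Scalable; }
  uint64_t getKnownMinValue() const { return KnownMin; }
};

// Hypothetical helper shaped like the early return added in this patch:
// a scalable size selects the SVE Z-register class. The parameter name is
// what tells the reader the value is measured in bits. The fixed-size
// branch is a placeholder, not the real class selection.
std::string minClassName(MiniTypeSize SizeInBits) {
  if (SizeInBits.isScalable())
    return "ZPR";
  return SizeInBits.getKnownMinValue() <= 64 ? "FPR64" : "FPR128";
}
```

Naming the parameter `SizeInBits` keeps a call such as `minClassName(MiniTypeSize::getScalable(128))` unambiguous even though the type itself would accept a byte count just as readily.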

bool GetAllRegSet = false) {
+ if (SizeInBits.isScalable()) {
+ assert(RB.getID() == AArch64::FPRRegBankID &&
+ "Expected FPR regbank for scalable type size");
+ return &AArch64::ZPRRegClass;
+ }

unsigned RegBankID = RB.getID();

if (RegBankID == AArch64::GPRRegBankID) {
@@ -939,8 +945,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
Register SrcReg = I.getOperand(1).getReg();
const RegisterBank &DstRegBank = *RBI.getRegBank(DstReg, MRI, TRI);
const RegisterBank &SrcRegBank = *RBI.getRegBank(SrcReg, MRI, TRI);
- unsigned DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
- unsigned SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
+
+ TypeSize DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
+ TypeSize SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);

// Special casing for cross-bank copies of s1s. We can technically represent
// a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -951,7 +958,7 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
// register bank. Or make a new helper that carries along some constraint
// information.
if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
- SrcSize = DstSize = 32;
+ SrcSize = DstSize = TypeSize::getFixed(32);

return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1016,8 +1023,8 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
return false;
}

- unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
- unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
+ const TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
+ const TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
unsigned SubReg;

// If the source bank doesn't support a subregister copy small enough,
49 changes: 27 additions & 22 deletions llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -163,17 +163,18 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min; \
(void)PartialMapDstIdx; \
(void)PartialMapSrcIdx; \
- const ValueMapping *Map = getCopyMapping( \
- AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size); \
+ const ValueMapping *Map = getCopyMapping(AArch64::RBNameDst##RegBankID, \
+ AArch64::RBNameSrc##RegBankID, \
+ TypeSize::getFixed(Size)); \
(void)Map; \
assert(Map[0].BreakDown == \
&AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] && \
- Map[0].NumBreakDowns == 1 && #RBNameDst #Size \
- " Dst is incorrectly initialized"); \
+ Map[0].NumBreakDowns == 1 && \
+ #RBNameDst #Size " Dst is incorrectly initialized"); \
assert(Map[1].BreakDown == \
&AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] && \
- Map[1].NumBreakDowns == 1 && #RBNameSrc #Size \
- " Src is incorrectly initialized"); \
+ Map[1].NumBreakDowns == 1 && \
+ #RBNameSrc #Size " Src is incorrectly initialized"); \
\
} while (false)

@@ -218,7 +219,7 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(

unsigned AArch64RegisterBankInfo::copyCost(const RegisterBank &A,
const RegisterBank &B,
- TypeSize Size) const {
+ const TypeSize Size) const {
// What do we do with different size?
// copy are same size.
// Will introduce other hooks for different size:
@@ -258,6 +259,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
case AArch64::QQQRegClassID:
case AArch64::QQQQRegClassID:
case AArch64::ZPRRegClassID:
+ case AArch64::ZPR_3bRegClassID:
return getRegBank(AArch64::FPRRegBankID);
case AArch64::GPR32commonRegClassID:
case AArch64::GPR32RegClassID:
@@ -304,7 +306,7 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
case TargetOpcode::G_OR: {
// 32 and 64-bit or can be mapped on either FPR or
// GPR for the same cost.
- unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+ TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
if (Size != 32 && Size != 64)
break;

@@ -325,7 +327,7 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
return AltMappings;
}
case TargetOpcode::G_BITCAST: {
- unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+ TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
if (Size != 32 && Size != 64)
break;

@@ -365,7 +367,7 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
return AltMappings;
}
case TargetOpcode::G_LOAD: {
- unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+ TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
if (Size != 64)
break;

@@ -377,15 +379,17 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
InstructionMappings AltMappings;
const InstructionMapping &GPRMapping = getInstructionMapping(
/*ID*/ 1, /*Cost*/ 1,
- getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
- // Addresses are GPR 64-bit.
- getValueMapping(PMI_FirstGPR, 64)}),
+ getOperandsMapping(
+ {getValueMapping(PMI_FirstGPR, Size),
+ // Addresses are GPR 64-bit.
+ getValueMapping(PMI_FirstGPR, TypeSize::getFixed(64))}),
/*NumOperands*/ 2);
const InstructionMapping &FPRMapping = getInstructionMapping(
/*ID*/ 2, /*Cost*/ 1,
- getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
- // Addresses are GPR 64-bit.
- getValueMapping(PMI_FirstGPR, 64)}),
+ getOperandsMapping(
+ {getValueMapping(PMI_FirstFPR, Size),
+ // Addresses are GPR 64-bit.
+ getValueMapping(PMI_FirstGPR, TypeSize::getFixed(64))}),
/*NumOperands*/ 2);

AltMappings.push_back(&GPRMapping);
@@ -437,7 +441,7 @@ AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
"This code is for instructions with 3 or less operands");

LLT Ty = MRI.getType(MI.getOperand(0).getReg());
- unsigned Size = Ty.getSizeInBits();
+ TypeSize Size = Ty.getSizeInBits();
bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);

PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
@@ -714,9 +718,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
// If both RB are null that means both registers are generic.
// We shouldn't be here.
assert(DstRB && SrcRB && "Both RegBank were nullptr");
- unsigned Size = getSizeInBits(DstReg, MRI, TRI);
+ TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
return getInstructionMapping(
- DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::getFixed(Size)),
+ DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
// We only care about the mapping of the destination.
/*NumOperands*/ 1);
@@ -727,15 +731,15 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case TargetOpcode::G_BITCAST: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
- unsigned Size = DstTy.getSizeInBits();
+ TypeSize Size = DstTy.getSizeInBits();
bool DstIsGPR = !DstTy.isVector() && DstTy.getSizeInBits() <= 64;
bool SrcIsGPR = !SrcTy.isVector() && SrcTy.getSizeInBits() <= 64;
const RegisterBank &DstRB =
DstIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
const RegisterBank &SrcRB =
SrcIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
return getInstructionMapping(
- DefaultMappingID, copyCost(DstRB, SrcRB, TypeSize::getFixed(Size)),
+ DefaultMappingID, copyCost(DstRB, SrcRB, Size),
getCopyMapping(DstRB.getID(), SrcRB.getID(), Size),
// We only care about the mapping of the destination for COPY.
/*NumOperands*/ Opc == TargetOpcode::G_BITCAST ? 2 : 1);
@@ -1126,7 +1130,8 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
LLT Ty = MRI.getType(MI.getOperand(Idx).getReg());
if (!Ty.isValid())
continue;
- auto Mapping = getValueMapping(OpRegBankIdx[Idx], OpSize[Idx]);
+ auto Mapping =
+ getValueMapping(OpRegBankIdx[Idx], TypeSize::getFixed(OpSize[Idx]));
if (!Mapping->isValid())
return getInvalidInstructionMapping();

6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h
@@ -70,7 +70,7 @@ class AArch64GenRegisterBankInfo : public RegisterBankInfo {
PartialMappingIdx LastAlias,
ArrayRef<PartialMappingIdx> Order);

- static unsigned getRegBankBaseIdxOffset(unsigned RBIdx, unsigned Size);
+ static unsigned getRegBankBaseIdxOffset(unsigned RBIdx, TypeSize Size);

/// Get the pointer to the ValueMapping representing the RegisterBank
/// at \p RBIdx with a size of \p Size.
@@ -80,13 +80,13 @@ class AArch64GenRegisterBankInfo : public RegisterBankInfo {
///
/// \pre \p RBIdx != PartialMappingIdx::None
static const RegisterBankInfo::ValueMapping *
- getValueMapping(PartialMappingIdx RBIdx, unsigned Size);
+ getValueMapping(PartialMappingIdx RBIdx, TypeSize Size);

/// Get the pointer to the ValueMapping of the operands of a copy
/// instruction from the \p SrcBankID register bank to the \p DstBankID
/// register bank with a size of \p Size.
static const RegisterBankInfo::ValueMapping *
- getCopyMapping(unsigned DstBankID, unsigned SrcBankID, unsigned Size);
+ getCopyMapping(unsigned DstBankID, unsigned SrcBankID, TypeSize Size);

/// Get the instruction mapping for G_FPEXT.
///
@@ -0,0 +1,44 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -global-isel-abort=1 -aarch64-enable-gisel-sve=1 %s -o - | FileCheck %s

;; Test the correct usage of the Z registers with multiple SVE arguments.

define void @formal_argument_nxv16i8_2(<vscale x 16 x i8> %0, <vscale x 16 x i8> %1, ptr %p) {
; CHECK-LABEL: formal_argument_nxv16i8_2:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: st1b { z1.b }, p0, [x0]
; CHECK-NEXT: ret
store <vscale x 16 x i8> %0, ptr %p, align 16
store <vscale x 16 x i8> %1, ptr %p, align 16
ret void
}

define void @formal_argument_nxv16i8_8(
; CHECK-LABEL: formal_argument_nxv16i8_8:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: st1b { z1.b }, p0, [x0]
; CHECK-NEXT: st1b { z2.b }, p0, [x0]
; CHECK-NEXT: st1b { z3.b }, p0, [x0]
; CHECK-NEXT: st1b { z4.b }, p0, [x0]
; CHECK-NEXT: st1b { z5.b }, p0, [x0]
; CHECK-NEXT: st1b { z6.b }, p0, [x0]
; CHECK-NEXT: st1b { z7.b }, p0, [x0]
; CHECK-NEXT: ret
<vscale x 16 x i8> %0, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, <vscale x 16 x i8> %3,
<vscale x 16 x i8> %4, <vscale x 16 x i8> %5, <vscale x 16 x i8> %6, <vscale x 16 x i8> %7,
ptr %p) {

store <vscale x 16 x i8> %0, ptr %p, align 16
store <vscale x 16 x i8> %1, ptr %p, align 16
store <vscale x 16 x i8> %2, ptr %p, align 16
store <vscale x 16 x i8> %3, ptr %p, align 16
store <vscale x 16 x i8> %4, ptr %p, align 16
store <vscale x 16 x i8> %5, ptr %p, align 16
store <vscale x 16 x i8> %6, ptr %p, align 16
store <vscale x 16 x i8> %7, ptr %p, align 16
ret void
}