Skip to content

[AArch64][SME2] Improve register allocation of multi-vector SME intrinsics #116399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
TargetRegisterClass ContiguousClass,
TargetRegisterClass StridedClass,
unsigned ContiguousOpc, unsigned StridedOpc);
bool expandFormTuplePseudo(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
MachineBasicBlock::iterator &NextMBBI,
unsigned Size);
bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
unsigned BitSize);

Expand Down Expand Up @@ -1142,6 +1146,32 @@ bool AArch64ExpandPseudo::expandMultiVecPseudo(
return true;
}

bool AArch64ExpandPseudo::expandFormTuplePseudo(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
MachineBasicBlock::iterator &NextMBBI, unsigned Size) {
assert(Size == 2 || Size == 4 && "Invalid Tuple Size");
MachineInstr &MI = *MBBI;
Register ReturnTuple = MI.getOperand(0).getReg();

const TargetRegisterInfo *TRI =
MBB.getParent()->getSubtarget().getRegisterInfo();
for (unsigned I = 0; I < Size; ++I) {
Register FormTupleOpReg = MI.getOperand(I + 1).getReg();
Register ReturnTupleSubReg =
TRI->getSubReg(ReturnTuple, AArch64::zsub0 + I);
// Add copies to ensure the subregisters remain in the correct order
// for any contigious operation they are used by.
if (FormTupleOpReg != ReturnTupleSubReg)
BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORR_ZZZ))
.addReg(ReturnTupleSubReg, RegState::Define)
.addReg(FormTupleOpReg)
.addReg(FormTupleOpReg);
}

MI.eraseFromParent();
return true;
}

/// If MBBI references a pseudo instruction that should be expanded here,
/// do the expansion and return true. Otherwise return false.
bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
Expand Down Expand Up @@ -1724,6 +1754,10 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
return expandMultiVecPseudo(
MBB, MBBI, AArch64::ZPR4RegClass, AArch64::ZPR4StridedRegClass,
AArch64::LDNT1D_4Z, AArch64::LDNT1D_4Z_STRIDED);
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 2);
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 4);
}
return false;
}
Expand Down
71 changes: 71 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8641,6 +8641,56 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
return ZExtBool;
}

// The FORM_TRANSPOSED_REG_TUPLE pseudo should only be used if the
// input operands are copy nodes where the source register is in a
// StridedOrContiguous class. For example:
//
// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
// %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
//
bool shouldUseFormStridedPseudo(MachineInstr &MI) {
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();

const TargetRegisterClass *RegClass = nullptr;
switch (MI.getOpcode()) {
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
RegClass = &AArch64::ZPR2StridedOrContiguousRegClass;
break;
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
RegClass = &AArch64::ZPR4StridedOrContiguousRegClass;
break;
default:
llvm_unreachable("Unexpected opcode.");
}

MCRegister SubReg = MCRegister::NoRegister;
for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
MachineOperand &MO = MI.getOperand(I);
assert(MO.isReg() && "Unexpected operand to FORM_TRANSPOSED_REG_TUPLE");

MachineOperand *Def = MRI.getOneDef(MO.getReg());
if (!Def || !Def->getParent()->isCopy())
return false;

const MachineOperand &CopySrc = Def->getParent()->getOperand(1);
unsigned OpSubReg = CopySrc.getSubReg();
if (SubReg == MCRegister::NoRegister)
SubReg = OpSubReg;

MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
return false;
}

return true;
}

void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
SDNode *Node) const {
// Live-in physreg copies that are glued to SMSTART are applied as
Expand All @@ -8666,6 +8716,27 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
}
}

if (MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO) {
// If input values to the FORM_TRANSPOSED_REG_TUPLE pseudo aren't copies
// from a StridedOrContiguous class, fall back on REG_SEQUENCE node.
if (shouldUseFormStridedPseudo(MI))
return;

const TargetInstrInfo *TII = Subtarget->getInstrInfo();
MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
TII->get(TargetOpcode::REG_SEQUENCE),
MI.getOperand(0).getReg());

for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
MIB.add(MI.getOperand(I));
MIB.addImm(AArch64::zsub0 + (I - 1));
}

MI.eraseFromParent();
return;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
MIB.addImm(SubRegs[I - 1]);
MIB.addImm(AArch64::zsub0 + (I-1));

Then you can remove SubRegs[].


// Add an implicit use of 'VG' for ADDXri/SUBXri, which are instructions that
// have nothing to do with VG, were it not that they are used to materialise a
// frame-address. If they contain a frame-index to a scalable vector, this
Expand Down
52 changes: 52 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,58 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
}
}

// FORM_TRANSPOSED_REG_TUPLE nodes are created to improve register allocation
// where a consecutive multi-vector tuple is constructed from the same indices
// of multiple strided loads. This may still result in unnecessary copies
// between the loads and the tuple. Here we try to return a hint to assign the
// contiguous ZPRMulReg starting at the same register as the first operand of
// the pseudo, which should be a subregister of the first strided load.
//
// For example, if the first strided load has been assigned $z16_z20_z24_z28
// and the operands of the pseudo are each accessing subregister zsub2, we
// should look through through Order to find a contiguous register which
// begins with $z24 (i.e. $z24_z25_z26_z27).
//
bool AArch64RegisterInfo::getRegAllocationHints(
Register VirtReg, ArrayRef<MCPhysReg> Order,
SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
const MachineRegisterInfo &MRI = MF.getRegInfo();

for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
if (MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
MF, VRM);

unsigned FirstOpSubReg = MI.getOperand(1).getSubReg();
switch (FirstOpSubReg) {
case AArch64::zsub0:
case AArch64::zsub1:
case AArch64::zsub2:
case AArch64::zsub3:
break;
default:
continue;
}

// Look up the physical register mapped to the first operand of the pseudo.
Register FirstOpVirtReg = MI.getOperand(1).getReg();
if (!VRM->hasPhys(FirstOpVirtReg))
continue;

MCRegister TupleStartReg =
getSubReg(VRM->getPhys(FirstOpVirtReg), FirstOpSubReg);
for (unsigned I = 0; I < Order.size(); ++I)
if (MCRegister R = getSubReg(Order[I], AArch64::zsub0))
if (R == TupleStartReg)
Hints.push_back(Order[I]);
}

return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
VRM);
}

unsigned AArch64RegisterInfo::getLocalAddressRegister(
const MachineFunction &MF) const {
const auto &MFI = MF.getFrameInfo();
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ class AArch64RegisterInfo final : public AArch64GenRegisterInfo {
unsigned getRegPressureLimit(const TargetRegisterClass *RC,
MachineFunction &MF) const override;

bool getRegAllocationHints(Register VirtReg, ArrayRef<MCPhysReg> Order,
SmallVectorImpl<MCPhysReg> &Hints,
const MachineFunction &MF, const VirtRegMap *VRM,
const LiveRegMatrix *Matrix) const override;

unsigned getLocalAddressRegister(const MachineFunction &MF) const;
bool regNeedsCFI(unsigned Reg, unsigned &RegToUseForCFI) const;

Expand Down
28 changes: 26 additions & 2 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0, 4>", []>;

def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;

// The FORM_TRANSPOSED_REG_TUPLE pseudos defined below are intended to
// improve register allocation for intrinsics which use strided and contiguous
// multi-vector registers, avoiding unnecessary copies.
// If the operands of the pseudo are copies where the source register is in
// the StridedOrContiguous class, the pseudo is used to provide a hint to the
// register allocator suggesting a contigious multi-vector register which
// matches the subregister sequence used by the operands.
// If the operands do not match this pattern, the pseudos are expanded
// to a REG_SEQUENCE using the post-isel hook.

def FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO :
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs a description of why we add these pseudos, and a comment that we expand them to REG_SEQUENCE with the post-isel hook if they don't meet certain criteria.

Pseudo<(outs ZPR2Mul2:$tup),
(ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]>{
let hasSideEffects = 0;
let hasPostISelHook = 1;
}

def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO :
Pseudo<(outs ZPR4Mul4:$tup),
(ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]>{
let hasSideEffects = 0;
let hasPostISelHook = 1;
}

def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
[SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
Expand Down Expand Up @@ -172,14 +196,14 @@ class SME2_ZA_TwoOp_VG2_Multi_Index_Pat<string name, SDPatternOperator intrinsic
Operand imm_ty, ComplexPattern tileslice>
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm, (i32 imm_ty:$i)),
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
(REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1), zpr_ty:$Zm, imm_ty:$i)>;
(FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>;

class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
Operand imm_ty, ComplexPattern tileslice>
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm, (i32 imm_ty:$i)),
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
(REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3),
(FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
zpr_ty:$Zm, imm_ty:$i)>;

class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
Expand Down
Loading
Loading