Skip to content

Commit 2fbfaff

Browse files
[AArch64][SME] Make getRegAllocationHints more specific for multi-vector loads (#123081)
getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all strided registers from this class to the list of hints. This patch changes getRegAllocationHints to restrict this list: - If the pseudo uses ZPRMul class, the first load must begin with a register which is a multiple of 2 or 4. - Only add a hint if it is part of a sequence of registers that do not already have any live intervals. This also contains changes to suggest hints when the load instructions and the FORM_TRANSPOSED pseudo use multi-vectors of different lengths, e.g. a pseudo with a 4-vector sequence of registers formed of one column extracted from four 2-vector loads.
1 parent 59a9a8f commit 2fbfaff

File tree

5 files changed

+807
-413
lines changed

5 files changed

+807
-413
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8765,17 +8765,9 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
87658765
bool shouldUseFormStridedPseudo(MachineInstr &MI) {
87668766
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
87678767

8768-
const TargetRegisterClass *RegClass = nullptr;
8769-
switch (MI.getOpcode()) {
8770-
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
8771-
RegClass = &AArch64::ZPR2StridedOrContiguousRegClass;
8772-
break;
8773-
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
8774-
RegClass = &AArch64::ZPR4StridedOrContiguousRegClass;
8775-
break;
8776-
default:
8777-
llvm_unreachable("Unexpected opcode.");
8778-
}
8768+
assert((MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
8769+
MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO) &&
8770+
"Unexpected opcode.");
87798771

87808772
MCRegister SubReg = MCRegister::NoRegister;
87818773
for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
@@ -8792,8 +8784,11 @@ bool shouldUseFormStridedPseudo(MachineInstr &MI) {
87928784
SubReg = OpSubReg;
87938785

87948786
MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
8787+
const TargetRegisterClass *CopySrcClass =
8788+
MRI.getRegClass(CopySrcOp->getReg());
87958789
if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
8796-
MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
8790+
(CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
8791+
CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass))
87978792
return false;
87988793
}
87998794

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp

Lines changed: 125 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "MCTargetDesc/AArch64InstPrinter.h"
2121
#include "llvm/ADT/BitVector.h"
2222
#include "llvm/BinaryFormat/Dwarf.h"
23+
#include "llvm/CodeGen/LiveRegMatrix.h"
2324
#include "llvm/CodeGen/MachineFrameInfo.h"
2425
#include "llvm/CodeGen/MachineInstrBuilder.h"
2526
#include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -1097,7 +1098,11 @@ bool AArch64RegisterInfo::getRegAllocationHints(
10971098
Register VirtReg, ArrayRef<MCPhysReg> Order,
10981099
SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
10991100
const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
1100-
const MachineRegisterInfo &MRI = MF.getRegInfo();
1101+
1102+
auto &ST = MF.getSubtarget<AArch64Subtarget>();
1103+
if (!ST.hasSME() || !ST.isStreaming())
1104+
return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
1105+
VRM);
11011106

11021107
// The SVE calling convention preserves registers Z8-Z23. As a result, there
11031108
// are no ZPR2Strided or ZPR4Strided registers that do not overlap with the
@@ -1107,26 +1112,127 @@ bool AArch64RegisterInfo::getRegAllocationHints(
11071112
// FORM_TRANSPOSED_REG_TUPLE pseudo, we want to favour reducing copy
11081113
// instructions over reducing the number of clobbered callee-save registers,
11091114
// so we add the strided registers as a hint.
1115+
const MachineRegisterInfo &MRI = MF.getRegInfo();
11101116
unsigned RegID = MRI.getRegClass(VirtReg)->getID();
1111-
// Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
1112-
if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
1113-
RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
1114-
any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
1115-
return Use.getOpcode() ==
1116-
AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
1117-
Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
1118-
})) {
1119-
const TargetRegisterClass *StridedRC =
1120-
RegID == AArch64::ZPR2StridedOrContiguousRegClassID
1121-
? &AArch64::ZPR2StridedRegClass
1122-
: &AArch64::ZPR4StridedRegClass;
1123-
1124-
for (MCPhysReg Reg : Order)
1125-
if (StridedRC->contains(Reg))
1126-
Hints.push_back(Reg);
1117+
if (RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
1118+
RegID == AArch64::ZPR4StridedOrContiguousRegClassID) {
1119+
1120+
// Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
1121+
for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
1122+
if (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
1123+
Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
1124+
continue;
1125+
1126+
unsigned UseOps = Use.getNumOperands() - 1;
1127+
const TargetRegisterClass *StridedRC;
1128+
switch (RegID) {
1129+
case AArch64::ZPR2StridedOrContiguousRegClassID:
1130+
StridedRC = &AArch64::ZPR2StridedRegClass;
1131+
break;
1132+
case AArch64::ZPR4StridedOrContiguousRegClassID:
1133+
StridedRC = &AArch64::ZPR4StridedRegClass;
1134+
break;
1135+
default:
1136+
llvm_unreachable("Unexpected RegID");
1137+
}
11271138

1128-
return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
1129-
VRM);
1139+
SmallVector<MCPhysReg, 4> StridedOrder;
1140+
for (MCPhysReg Reg : Order)
1141+
if (StridedRC->contains(Reg))
1142+
StridedOrder.push_back(Reg);
1143+
1144+
int OpIdx = Use.findRegisterUseOperandIdx(VirtReg, this);
1145+
assert(OpIdx != -1 && "Expected operand index from register use.");
1146+
1147+
unsigned TupleID = MRI.getRegClass(Use.getOperand(0).getReg())->getID();
1148+
bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
1149+
TupleID == AArch64::ZPR4Mul4RegClassID;
1150+
1151+
const MachineOperand *AssignedRegOp = llvm::find_if(
1152+
make_range(Use.operands_begin() + 1, Use.operands_end()),
1153+
[&VRM](const MachineOperand &Op) {
1154+
return VRM->hasPhys(Op.getReg());
1155+
});
1156+
1157+
// Example:
1158+
//
1159+
// When trying to find a suitable register allocation for VirtReg %v2 in:
1160+
//
1161+
// %v0:zpr2stridedorcontiguous = ld1 p0/z, [...]
1162+
// %v1:zpr2stridedorcontiguous = ld1 p0/z, [...]
1163+
// %v2:zpr2stridedorcontiguous = ld1 p0/z, [...]
1164+
// %v3:zpr2stridedorcontiguous = ld1 p0/z, [...]
1165+
// %v4:zpr4mul4 = FORM_TRANSPOSED_X4 %v0:0, %v1:0, %v2:0, %v3:0
1166+
//
1167+
// One such suitable allocation would be:
1168+
//
1169+
// { z0, z8 } = ld1 p0/z, [...]
1170+
// { z1, z9 } = ld1 p0/z, [...]
1171+
// { z2, z10 } = ld1 p0/z, [...]
1172+
// { z3, z11 } = ld1 p0/z, [...]
1173+
// { z0, z1, z2, z3 } =
1174+
// FORM_TRANSPOSED_X4 {z0, z8}:0, {z1, z9}:0, {z2, z10}:0, {z3, z11}:0
1175+
//
1176+
// Below we distinguish two cases when trying to find a register:
1177+
// * None of the registers used by FORM_TRANSPOSED_X4 have been assigned
1178+
// yet. In this case the code muse ensure that there are at least UseOps
1179+
// free consecutive registers. If IsMulZPR is true, then the first of
1180+
// registers must also be a multiple of UseOps, e.g. { z0, z1, z2, z3 }
1181+
// is valid but { z1, z2, z3, z5 } is not.
1182+
// * One or more of the registers used by FORM_TRANSPOSED_X4 is already
1183+
// assigned a physical register, which means only checking that a
1184+
// consectutive range of free tuple registers exists which includes
1185+
// the assigned register.
1186+
// e.g. in the example above, if { z0, z8 } is already allocated for
1187+
// %v0, we just need to ensure that { z1, z9 }, { z2, z10 } and
1188+
// { z3, z11 } are also free. If so, we add { z2, z10 }.
1189+
1190+
if (AssignedRegOp == Use.operands_end()) {
1191+
// There are no registers already assigned to any of the pseudo
1192+
// operands. Look for a valid starting register for the group.
1193+
for (unsigned I = 0; I < StridedOrder.size(); ++I) {
1194+
MCPhysReg Reg = StridedOrder[I];
1195+
SmallVector<MCPhysReg> Regs;
1196+
1197+
// If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
1198+
// register of the first load should be a multiple of 2 or 4.
1199+
unsigned SubRegIdx = Use.getOperand(OpIdx).getSubReg();
1200+
if (IsMulZPR && (getSubReg(Reg, SubRegIdx) - AArch64::Z0) % UseOps !=
1201+
((unsigned)OpIdx - 1))
1202+
continue;
1203+
1204+
// In the example above, if VirtReg is the third operand of the
1205+
// tuple (%v2) and Reg == Z2_Z10, then we need to make sure that
1206+
// Z0_Z8, Z1_Z9 and Z3_Z11 are also available.
1207+
auto IsFreeConsecutiveReg = [&](unsigned UseOp) {
1208+
unsigned R = Reg - (OpIdx - 1) + UseOp;
1209+
return StridedRC->contains(R) &&
1210+
(UseOp == 0 ||
1211+
((getSubReg(R, AArch64::zsub0) - AArch64::Z0) ==
1212+
(getSubReg(R - 1, AArch64::zsub0) - AArch64::Z0) + 1)) &&
1213+
!Matrix->isPhysRegUsed(R);
1214+
};
1215+
if (all_of(iota_range<unsigned>(0U, UseOps, /*Inclusive=*/false),
1216+
IsFreeConsecutiveReg))
1217+
Hints.push_back(Reg);
1218+
}
1219+
} else {
1220+
// At least one operand already has a physical register assigned.
1221+
// Find the starting sub-register of this and use it to work out the
1222+
// correct strided register to suggest based on the current op index.
1223+
MCPhysReg TargetStartReg =
1224+
getSubReg(VRM->getPhys(AssignedRegOp->getReg()), AArch64::zsub0) +
1225+
(OpIdx - AssignedRegOp->getOperandNo());
1226+
1227+
for (unsigned I = 0; I < StridedOrder.size(); ++I)
1228+
if (getSubReg(StridedOrder[I], AArch64::zsub0) == TargetStartReg)
1229+
Hints.push_back(StridedOrder[I]);
1230+
}
1231+
1232+
if (!Hints.empty())
1233+
return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
1234+
MF, VRM);
1235+
}
11301236
}
11311237

11321238
for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {

0 commit comments

Comments
 (0)