[AArch64][SME] Make getRegAllocationHints more specific for multi-vector loads #123081
Conversation
@llvm/pr-subscribers-backend-aarch64
Author: Kerry McLaughlin (kmclaughlin-arm)
Changes: getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all strided registers from this class to the list of hints.
Patch is 64.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123081.diff
4 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 5973b63b5a8024..aac1dc9cb5c062 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -20,6 +20,7 @@
#include "MCTargetDesc/AArch64InstPrinter.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/BinaryFormat/Dwarf.h"
+#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -1107,23 +1108,83 @@ bool AArch64RegisterInfo::getRegAllocationHints(
// FORM_TRANSPOSED_REG_TUPLE pseudo, we want to favour reducing copy
// instructions over reducing the number of clobbered callee-save registers,
// so we add the strided registers as a hint.
+ const MachineInstr *TupleInst = nullptr;
unsigned RegID = MRI.getRegClass(VirtReg)->getID();
// Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
- any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
- return Use.getOpcode() ==
- AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
- Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+ any_of(MRI.use_nodbg_instructions(VirtReg), [&TupleInst](
+ const MachineInstr &Use) {
+ bool IsTuple =
+ Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
+ Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+ TupleInst = &Use;
+ return IsTuple;
})) {
- const TargetRegisterClass *StridedRC =
- RegID == AArch64::ZPR2StridedOrContiguousRegClassID
- ? &AArch64::ZPR2StridedRegClass
- : &AArch64::ZPR4StridedRegClass;
+ unsigned LdOps = TupleInst->getNumOperands() - 1;
+ const TargetRegisterClass *StridedRC = LdOps == 2
+ ? &AArch64::ZPR2StridedRegClass
+ : &AArch64::ZPR4StridedRegClass;
+ SmallVector<MCPhysReg, 4> StridedOrder;
for (MCPhysReg Reg : Order)
if (StridedRC->contains(Reg))
- Hints.push_back(Reg);
+ StridedOrder.push_back(Reg);
+
+ int OpIdx = TupleInst->findRegisterUseOperandIdx(VirtReg, this);
+ if (OpIdx == -1)
+ return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
+ MF, VRM);
+
+ unsigned TupleID =
+ MRI.getRegClass(TupleInst->getOperand(0).getReg())->getID();
+ bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+ TupleID == AArch64::ZPR4Mul4RegClassID;
+
+ if (OpIdx == 1) {
+ for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+ MCPhysReg Reg = StridedOrder[I];
+ unsigned FirstReg = getSubReg(Reg, AArch64::zsub0);
+
+ // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+ // register of the first load should be a multiple of 2 or 4.
+ if (IsMulZPR &&
+ (getSubReg(Reg, AArch64::zsub0) - AArch64::Z0) % LdOps != 0)
+ continue;
+ // Skip this register if it has any live intervals assigned.
+ if (Matrix->isPhysRegUsed(Reg))
+ continue;
+
+ bool CanAssign = true;
+ for (unsigned Next = 1; Next < LdOps; ++Next) {
+ // Ensure we can assign enough registers from the list for all loads.
+ if (I + Next >= StridedOrder.size()) {
+ CanAssign = false;
+ break;
+ }
+ // Ensure the subsequent registers are not live and that the starting
+ // sub-registers are sequential.
+ MCPhysReg NextReg = StridedOrder[I + Next];
+ if (Matrix->isPhysRegUsed(NextReg) ||
+ (getSubReg(NextReg, AArch64::zsub0) != FirstReg + Next)) {
+ CanAssign = false;
+ break;
+ }
+ }
+ if (CanAssign)
+ Hints.push_back(Reg);
+ }
+ } else if (VRM->hasPhys(TupleInst->getOperand(1).getReg())) {
+ // This is not the first load in the sequence. Find the register
+ // assigned to the first and match to a strided reg in the list.
+ MCPhysReg FirstLoadPhysReg =
+ VRM->getPhys(TupleInst->getOperand(1).getReg());
+ for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+ if (StridedOrder[I] == FirstLoadPhysReg &&
+ (I + (OpIdx - 1) < StridedOrder.size()))
+ Hints.push_back(StridedOrder[I + (OpIdx - 1)]);
+ }
+ }
return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
VRM);
diff --git a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
index 86ed63d743713c..4abc28c4863421 100644
--- a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
@@ -615,15 +615,17 @@ define void @udot_form_2x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
; CHECK-NEXT: str p8, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: ptrue pn8.b
; CHECK-NEXT: mov w8, wzr
-; CHECK-NEXT: str z9, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z8, [sp, #2, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: ld1b { z0.b, z8.b }, pn8/z, [x0]
-; CHECK-NEXT: ld1b { z1.b, z9.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z8.b, z9.b }, z0.b[0]
-; CHECK-NEXT: ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: str z11, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: str z10, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ld1b { z2.b, z10.b }, pn8/z, [x0]
+; CHECK-NEXT: ld1b { z3.b, z11.b }, pn8/z, [x0, x1]
+; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z2.b, z3.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx2], { z10.b, z11.b }, z0.b[0]
+; CHECK-NEXT: ldr z11, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z10, [sp, #2, mul vl] // 16-byte Folded Reload
; CHECK-NEXT: ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: addvl sp, sp, #3
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
@@ -638,6 +640,7 @@ entry:
%6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> poison, i32 0)
+ store <vscale x 16 x i8> %scalable_arg, ptr %ptr
ret void
}
@@ -699,33 +702,35 @@ define void @udot_form_4x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
; CHECK-NEXT: lsl x9, x1, #1
; CHECK-NEXT: str p8, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: ptrue pn8.b
-; CHECK-NEXT: str z15, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z23, [sp, #1, mul vl] // 16-byte Folded Spill
; CHECK-NEXT: mov w8, wzr
-; CHECK-NEXT: str z14, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: str z22, [sp, #2, mul vl] // 16-byte Folded Spill
; CHECK-NEXT: add x10, x9, x1
-; CHECK-NEXT: str z13, [sp, #3, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z12, [sp, #4, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z11, [sp, #5, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z10, [sp, #6, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z9, [sp, #7, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z8, [sp, #8, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: ld1b { z0.b, z4.b, z8.b, z12.b }, pn8/z, [x0]
-; CHECK-NEXT: ld1b { z1.b, z5.b, z9.b, z13.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: ld1b { z2.b, z6.b, z10.b, z14.b }, pn8/z, [x0, x9]
-; CHECK-NEXT: ld1b { z3.b, z7.b, z11.b, z15.b }, pn8/z, [x0, x10]
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z4.b - z7.b }, z0.b[0]
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z8.b - z11.b }, z0.b[0]
-; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z12.b - z15.b }, z0.b[0]
-; CHECK-NEXT: ldr z15, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z14, [sp, #2, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z13, [sp, #3, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z12, [sp, #4, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z11, [sp, #5, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z10, [sp, #6, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z9, [sp, #7, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z8, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: str z21, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z20, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z19, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z18, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z17, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z16, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ld1b { z16.b, z20.b, z24.b, z28.b }, pn8/z, [x0]
+; CHECK-NEXT: ld1b { z17.b, z21.b, z25.b, z29.b }, pn8/z, [x0, x1]
+; CHECK-NEXT: ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
+; CHECK-NEXT: ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT: udot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
+; CHECK-NEXT: ldr z23, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z22, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z21, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z20, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z19, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z18, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z17, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z16, [sp, #8, mul vl] // 16-byte Folded Reload
; CHECK-NEXT: ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: addvl sp, sp, #9
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
@@ -760,6 +765,7 @@ entry:
tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %8, <vscale x 16 x i8> %13, <vscale x 16 x i8> %18, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %4, <vscale x 16 x i8> %9, <vscale x 16 x i8> %14, <vscale x 16 x i8> %19, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %5, <vscale x 16 x i8> %10, <vscale x 16 x i8> %15, <vscale x 16 x i8> %20, <vscale x 16 x i8> poison, i32 0)
+ store <vscale x 16 x i8> %scalable_arg, ptr %ptr
ret void
}
@@ -863,15 +869,17 @@ define void @usdot_form_2x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
; CHECK-NEXT: str p8, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: ptrue pn8.b
; CHECK-NEXT: mov w8, wzr
-; CHECK-NEXT: str z9, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z8, [sp, #2, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: ld1b { z0.b, z8.b }, pn8/z, [x0]
-; CHECK-NEXT: ld1b { z1.b, z9.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z8.b, z9.b }, z0.b[0]
-; CHECK-NEXT: ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: str z11, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: str z10, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ld1b { z2.b, z10.b }, pn8/z, [x0]
+; CHECK-NEXT: ld1b { z3.b, z11.b }, pn8/z, [x0, x1]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z2.b, z3.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx2], { z10.b, z11.b }, z0.b[0]
+; CHECK-NEXT: ldr z11, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z10, [sp, #2, mul vl] // 16-byte Folded Reload
; CHECK-NEXT: ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: addvl sp, sp, #3
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
@@ -886,6 +894,7 @@ entry:
%6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> poison, i32 0)
+ store <vscale x 16 x i8> %scalable_arg, ptr %ptr
ret void
}
@@ -947,33 +956,35 @@ define void @usdot_form_4x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
; CHECK-NEXT: lsl x9, x1, #1
; CHECK-NEXT: str p8, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: ptrue pn8.b
-; CHECK-NEXT: str z15, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z23, [sp, #1, mul vl] // 16-byte Folded Spill
; CHECK-NEXT: mov w8, wzr
-; CHECK-NEXT: str z14, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: str z22, [sp, #2, mul vl] // 16-byte Folded Spill
; CHECK-NEXT: add x10, x9, x1
-; CHECK-NEXT: str z13, [sp, #3, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z12, [sp, #4, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z11, [sp, #5, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z10, [sp, #6, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z9, [sp, #7, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z8, [sp, #8, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: ld1b { z0.b, z4.b, z8.b, z12.b }, pn8/z, [x0]
-; CHECK-NEXT: ld1b { z1.b, z5.b, z9.b, z13.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: ld1b { z2.b, z6.b, z10.b, z14.b }, pn8/z, [x0, x9]
-; CHECK-NEXT: ld1b { z3.b, z7.b, z11.b, z15.b }, pn8/z, [x0, x10]
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z4.b - z7.b }, z0.b[0]
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z8.b - z11.b }, z0.b[0]
-; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z12.b - z15.b }, z0.b[0]
-; CHECK-NEXT: ldr z15, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z14, [sp, #2, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z13, [sp, #3, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z12, [sp, #4, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z11, [sp, #5, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z10, [sp, #6, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z9, [sp, #7, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z8, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: str z21, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z20, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z19, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z18, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z17, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: str z16, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ld1b { z16.b, z20.b, z24.b, z28.b }, pn8/z, [x0]
+; CHECK-NEXT: ld1b { z17.b, z21.b, z25.b, z29.b }, pn8/z, [x0, x1]
+; CHECK-NEXT: ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
+; CHECK-NEXT: ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT: usdot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
+; CHECK-NEXT: ldr z23, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z22, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z21, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z20, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z19, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z18, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z17, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z16, [sp, #8, mul vl] // 16-byte Folded Reload
; CHECK-NEXT: ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: addvl sp, sp, #9
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
@@ -1008,6 +1019,7 @@ entry:
tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %8, <vscale x 16 x i8> %13, <vscale x 16 x i8> %18, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %4, <vscale x 16 x i8> %9, <vscale x 16 x i8> %14, <vscale x 16 x i8> %19, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %5, <vscale x 16 x i8> %10, <vscale x 16 x i8> %15, <vscale x 16 x i8> %20, <vscale x 16 x i8> poison, i32 0)
+ store <vscale x 16 x i8> %scalable_arg, ptr %ptr
ret void
}
@@ -1113,15 +1125,17 @@ define void @sdot_form_2x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
; CHECK-NEXT: str p8, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: ptrue pn8.b
; CHECK-NEXT: mov w8, wzr
-; CHECK-NEXT: str z9, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: str z8, [sp, #2, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT: ld1b { z0.b, z8.b }, pn8/z, [x0]
-; CHECK-NEXT: ld1b { z1.b, z9.b }, pn8/z, [x0, x1]
-; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z8.b, z9.b }, z0.b[0]
-; CHECK-NEXT: ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT: ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: str z11, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: str z10, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT: ld1b { z2.b, z10.b }, pn8/z, [x0]
+; CHECK-NEXT: ld1b { z3.b, z11.b }, pn8/z, [x0, x1]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z2.b, z3.b }, z0.b[0]
+; CHECK-NEXT: sdot za.s[w8, 0, vgx2], { z10.b, z11.b }, z0.b[0]
+; CHECK-NEXT: ldr z11, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT: ldr z10, [sp, #2, mul vl] // 16-byte Folded Reload
; CHECK-NEXT: ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT: st1b { z0.b }, p0, [x0]
; CHECK-NEXT: addvl sp, sp, #3
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
@@ -1136,6 +1150,7 @@ entry:
%6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
tail call void @llvm.aarch64.sme.sdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> poison, i32 0)
tail call void @llvm.aarch64.sme.sdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> poison, i32 0)
+ store <vscale x 16 x i8> %scalable_arg, ptr %ptr
ret void
}
@@ -1197,33 +1212,35 @@ define void @sdot_form_4x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
; CHECK-NEXT: lsl x9, x1, #1
; CHECK-NEXT: str p8, [sp, #7, mul vl] // 2-byte Folded Spill
; CHECK-NEXT: ptrue pn8.b...
[truncated]
...
---
name: form_4x_tuple_many_live
Probably a lot of these fields are not necessary for this test (the LLVM IR above is probably also not required), could you clean this test up a bit?
Thanks for cleaning these up!
%26:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub0, %9.zsub0, %15.zsub0, %21.zsub0
$za = UDOT_VG4_M4ZZI_BtoS $za, %27, 0, %26, undef %28:zpr_4b, 0
%29:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub1, %9.zsub1, %15.zsub1, %21.zsub1
$za = UDOT_VG4_M4ZZI_BtoS $za, %27, 0, %29, undef %30:zpr_4b, 0
%31:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub2, %9.zsub2, %15.zsub2, %21.zsub2
%35:ppr_3b = PTRUE_B 31, implicit $vg
$za = UDOT_VG4_M4ZZI_BtoS $za, %27, 0, %31, undef %32:zpr_4b, 0
%33:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub3, %9.zsub3, %15.zsub3, %21.zsub3 |
could you do some manual re-scheduling of these instructions, to make the test easier to read?
if (StridedOrder[I] == FirstLoadPhysReg &&
    (I + (OpIdx - 1) < StridedOrder.size()))
  Hints.push_back(StridedOrder[I + (OpIdx - 1)]);
I don't think you can make the assumption that in the array Order, z0_z4_z8_z12 is followed by z1_z5_z9_z13, because if z1, z5, z9 or z13 is already required for something else, the next register in that list may be z2_z6_z10_z14. I think what you want to do instead is calculate the expected register and then add it to Hints iff it is contained in Order.
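A minimal sketch of that idea (illustrative only; FirstLoadPhysReg and OpIdx are names from the patch above, and it assumes the strided tuple registers are enumerated consecutively, so the tuple following the first load's register is FirstLoadPhysReg + 1, and so on):
// Compute the register expected for this operand from the register already
// assigned to the first load, and only hint it if the allocator offers it.
MCPhysReg Expected = FirstLoadPhysReg + (OpIdx - 1);
if (is_contained(Order, Expected))
  Hints.push_back(Expected);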
any_of(MRI.use_nodbg_instructions(VirtReg), [&TupleInst](
                                                 const MachineInstr &Use) {
  bool IsTuple =
      Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
      Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
  TupleInst = &Use;
  return IsTuple;
This any_of causes it to pick any FORM_TRANSPOSED_* use at random and only consider that one to analyse further. If it then doesn't match the criteria for adding a hint, it doesn't continue looking.
Instead, we should iterate through all uses and only if it finds any FORM_TRANSPOSED_* pseudo that matches the criteria, it should add the hint and then stop looking.
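A rough sketch of that structure, reusing only constructs that already appear in this patch (not the final code):
for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
  if (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
      Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
    continue;
  // ... try to derive hints from this Use ...
  if (!Hints.empty())
    break; // stop once a matching pseudo has produced a hint
}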
int OpIdx = TupleInst->findRegisterUseOperandIdx(VirtReg, this);
if (OpIdx == -1)
  return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
                                                   MF, VRM);
Should this be an assert instead? (I'd expect it to never be -1), otherwise this conditional exit should be moved above the loop that fills StridedOrder.
It should be an assert, yes. We should only ever reach this point if a use of VirtReg was already found in TupleInst.
      if (CanAssign)
        Hints.push_back(Reg);
    }
  } else if (VRM->hasPhys(TupleInst->getOperand(1).getReg())) {
Like mentioned above, this is making an assumption about the order in which the register allocator assigns registers, which I don't think is an assumption you can make. Can you write this case in such a way that it looks for any operand that has a physreg assigned, and works backward from there?
bool CanAssign = true;
for (unsigned Next = 1; Next < LdOps; ++Next) {
  // Ensure we can assign enough registers from the list for all loads.
  if (I + Next >= StridedOrder.size()) {
    CanAssign = false;
    break;
  }
  // Ensure the subsequent registers are not live and that the starting
  // sub-registers are sequential.
  MCPhysReg NextReg = StridedOrder[I + Next];
  if (Matrix->isPhysRegUsed(NextReg) ||
      (getSubReg(NextReg, AArch64::zsub0) != FirstReg + Next)) {
    CanAssign = false;
    break;
  }
}
if (CanAssign)
  Hints.push_back(Reg);
This loop can be written more compactly doing something along the lines of:
for (unsigned Next = 1; Next < LdOps; ++Next) {
if (!is_contained(StridedOrder, Reg + Next))
// cannot assign
}
I think checking is_contained for Reg + Next still makes assumptions about the order of the StridedOrder list. Instead I've rewritten this to find a strided register which matches FirstReg + Next, and to discard the current Reg if none was found.
bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
                TupleID == AArch64::ZPR4Mul4RegClassID;

if (OpIdx == 1) {
This is making an assumption about the order in which the register allocator visits the operands and allocates registers. I think that instead you want to distinguish between the cases where:
- None of the operands for Use are allocated yet => then check if for each of the operands, we have consecutive tuple registers available, and if so then add the hint for the given OpIdx.
- Any of the operands for Use is already allocated => find the corresponding tuple reg for the current operand index, if available.
I've rewritten this as suggested so that we now look for hints where all consecutive tuple registers are available if none of the operands have been allocated yet.
There is a new test in sme2-multivec-regalloc.mir to check that we still suggest the correct hints if the first load instruction to be allocated is not the first operand.
unsigned AssignedOp = 0;
if (!any_of(make_range(Use.operands_begin() + 1, Use.operands_end()),
            [&](const MachineOperand &Op) {
              if (!VRM->hasPhys(Op.getReg()))
                return false;
              AssignedOp = Op.getOperandNo();
              return true;
            })) {
This looks a little odd. If it's doing a search for an operand that has its register allocated, then I'd suggest using llvm::find_if to find the operand:
const MachineOperand *AssignedRegOp = llvm::find_if(
make_range(Use.operands_begin() + 1, Use.operands_end()),
[&VRM](const MachineOperand &Op) { return VRM->hasPhys(Op.getReg()); });
and then:
if (AssignedRegOp != Use.operands_end()) {
..
} else {
..
}
Thanks, this is easier to follow!
// Skip this register if it has any live intervals assigned.
if (Matrix->isPhysRegUsed(Reg))
  continue;
I guess this is redundant, because it would not have been suggested as a free register in Order if it was already allocated for something else.
MCPhysReg RegToAssign = Reg;
for (unsigned Next = 1; Next < LdOps; ++Next) {
  MCPhysReg Strided = GetRegStartingAt(FirstReg + Next);
  if (Strided == AArch64::NoRegister ||
      Matrix->isPhysRegUsed(Strided)) {
    RegToAssign = AArch64::NoRegister;
    break;
  }
  if (Next == (unsigned)OpIdx - 1)
    RegToAssign = Strided;
}
if (RegToAssign != AArch64::NoRegister)
  Hints.push_back(RegToAssign);
This code is a little convoluted. I think you could also avoid the extra nested loop in GetRegStartingAt, by doing the following: if Reg is e.g. Z1_Z5_Z9_Z13, then loop from Z0_Z4_Z8_Z12..Z3_Z7_Z11_Z15 and check if any of them has allocated a phys reg. If not, then you can add Z1_Z5_Z9_Z13.
SmallVector<MCPhysReg> Regs;
unsigned FirstReg = Reg - OpIdx + 1;
for (unsigned I = 0; I < LdOps; ++I)
Regs.push_back(FirstReg + I);
if (all_of(Regs,
[&](MCPhysReg R) { return !Matrix->isPhysRegUsed(R); }))
Hints.push_back(FirstReg + OpIdx - 1);
I've rewritten this as suggested, however I included !is_contained(StridedOrder, FirstReg + I) in the first loop because I think it could be possible for this register to exist outside of the list. It's for this reason that I've also left in the filtering by start register here, as I need StridedOrder to contain every possible strided register for the is_contained check.
I've also added a check to make sure that the starting registers are consecutive in the loop, and created a test where this is necessary in sme2-multivec-regalloc.mir.
    RegID == AArch64::ZPR2StridedOrContiguousRegClassID
        ? &AArch64::ZPR2StridedRegClass
        : &AArch64::ZPR4StridedRegClass;
for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
The code here seems to be looping through uses of VirtReg, but it does that regardless of its register class or feature flags.
What about returning early from the function if the function has no SME or is not in streaming mode?
And also moving this loop inside an if (RegID == .. || RegID == ...) condition?
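A minimal sketch of such an early exit, assuming the AArch64 subtarget exposes hasSME() and isStreaming() queries:
// Bail out to the default hints if the function cannot use SME multi-vector
// instructions at all.
const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
if (!STI.hasSME() || !STI.isStreaming())
  return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
                                                   VRM);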
// If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
// register of the first load should be a multiple of 2 or 4.
if (IsMulZPR && (FirstReg - AArch64::Z0) % LdOps != 0)
Could you move this filtering to where it populates the StridedOrder vector?
liveins:
  - { reg: '$x0', virtual-reg: '%0' }
  - { reg: '$x1', virtual-reg: '%1' }
  - { reg: '$z0', virtual-reg: '%2' }
  - { reg: '$z17', virtual-reg: '%3' }
nit: I think you can remove these as well, because this information will be otherwise recomputed.
…oads.
getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all strided registers from this class to the list of hints. This patch changes getRegAllocationHints to restrict this list:
- If the pseudo uses the ZPRMul class, the first load must begin with a register which is a multiple of 2 or 4.
- Only add a strided register to the list if it does not already have any live intervals.
…TRANSPOSED pseudo, find hints based on whether any operands have been allocated yet.
- Remove any_of which finds FORM_TRANSPOSED uses of VirtReg & instead iterate through all uses.
- Clean up sme2-multivec-regalloc.mir & add new test which changes the allocation order of the load instructions.
- Add early exit if the function does not have SME or is not in streaming mode
- Move loop over uses of VirtReg into if block checking the RegID
- Remove GetRegStartingAt
force-pushed from b6db334 to 81c3d47
    Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
  continue;

unsigned LdOps = Use.getNumOperands() - 1;
I just realised that LdOps is a misnomer, because it is assigned the number of operands of the use (form_reg_tuple). If the use here has 4 operands, it could still be that the load has 2, e.g.
ld1 { z0, z8 }, p0/z, [...]
ld1 { z1, z9 }, p0/z, [...]
ld1 { z2, z10 }, p0/z, [...]
ld1 { z3, z11 }, p0/z, [...]
{z0, z1, z2, z3} = form_reg_tuple {z0, z8}:0, {z1, z9}:0, {z2, z10}:0, {z3, z11}:0
The uses below assume this is about the number of operands of the Use, so it seems like it's just the name that's wrong.
The same is not true for StridedRC, which uses the wrong register class (it should decide the strided RC based on RegID instead).
In the last commit I did change StridedRC to be based on RegID; this was to support the case where the sizes of the multi-vectors used by the loads and the pseudo are different, as you've described here. I've now also renamed LdOps to UseOps.
for (unsigned I = 0; I < StridedOrder.size(); ++I) {
  MCPhysReg Reg = StridedOrder[I];
  SmallVector<MCPhysReg> Regs;
  unsigned FirstStridedReg = Reg - OpIdx + 1;
I would avoid doing this, because Reg - OpIdx + 1 may not be an SVE tuple register, which means that getSubReg(FirstStridedReg, AArch64::zsub0) might fail. For example, if the first tuple register in the list were Z0_Z1 and we're looking at the second operand of the form_*tuple pseudo, i.e. OpIdx = 2, then FirstStridedReg would be X26_X27.
You can instead write this as:
unsigned SubRegIdx = Use.getOperand(OpIdx).getSubReg();
if (IsMulZPR && (getSubReg(Reg, SubRegIdx) - AArch64::Z0) % UseOps !=
((unsigned)OpIdx - 1))
continue;
Thanks, I added the is_contained later to make sure getSubReg would not fail, but missed that it could also fail here.
for (unsigned Op = 0; Op < LdOps; ++Op) {
  if (!is_contained(StridedOrder, FirstStridedReg + Op) ||
      getSubReg(FirstStridedReg + Op, AArch64::zsub0) !=
          FirstSubReg + Op)
    break;
  Regs.push_back(FirstStridedReg + Op);
}

if (Regs.size() == LdOps && all_of(Regs, [&](MCPhysReg R) {
      return !Matrix->isPhysRegUsed(R);
    }))
nit: This could be rewritten in such a way that it doesn't need SmallVector<MCPhysReg> Regs as an intermediate step, e.g.
auto IsFreeConsecutiveRegs = [&](unsigned I) {
// conditions
};
if (all_of(iota_range<unsigned>(0U, UseOps, /*Inclusive=*/false),
IsFreeConsecutiveRegs))
Hints.push_back(Reg);
      return VRM->hasPhys(Op.getReg());
    });

if (AssignedRegOp == Use.operands_end()) {
I think an example would be useful here.
- Rewrote the case where no registers are already assigned to avoid creating the Regs vector.
- Added more comments with an example.
LGTM with nit addressed.
Would it make sense to update the title/commit message to reflect that you've also made it less strict? i.e. extracting one column from 4 x 2-vector loads can be used with an instruction that requires sequential regs now.
@@ -1099,6 +1100,11 @@ bool AArch64RegisterInfo::getRegAllocationHints(
    const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
  const MachineRegisterInfo &MRI = MF.getRegInfo();
nit: this can be moved down (after the if(..) condition)
Thanks for approving this! I've reworded the title a bit and changed the commit message to include allowing x2 & x4 multivector loads and intrinsics.
getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions
which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all
strided registers from this class to the list of hints.
This patch changes getRegAllocationHints to restrict this list:
- If the pseudo uses the ZPRMul class, the first load must begin with a register which is a multiple of 2 or 4.
- Only add a strided register to the list if it does not already have any live intervals.
This also contains changes to suggest hints when the load instructions and
the FORM_TRANSPOSED pseudo use multi-vectors of different lengths,
e.g. a pseudo with a 4-vector sequence of registers formed of one column
extracted from four 2-vector loads.