
[AArch64][SME] Make getRegAllocationHints more specific for multi-vector loads #123081


Merged
merged 7 commits into llvm:main on Jan 30, 2025

Conversation

@kmclaughlin-arm (Contributor) commented Jan 15, 2025

getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions
which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all
strided registers from this class to the list of hints.
This patch changes getRegAllocationHints to restrict this list:

  • If the pseudo uses the ZPRMul class, the first load must begin with a register
    whose number is a multiple of 2 or 4.
  • Only add a hint if the register is part of a sequence of strided registers
    that do not already have any live intervals assigned.

The patch also adds support for suggesting hints when the load instructions and
the FORM_TRANSPOSED pseudo use multi-vectors of different lengths, e.g. a pseudo
with a 4-vector sequence of registers formed from one column extracted across
four 2-vector loads, as illustrated below.
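
For illustration, the mixed-length case has roughly the following shape (this
repeats the example a reviewer gives later in the conversation; the register
numbers are illustrative, not taken from the patch):

  ld1 { z0, z8 },  p0/z, [...]
  ld1 { z1, z9 },  p0/z, [...]
  ld1 { z2, z10 }, p0/z, [...]
  ld1 { z3, z11 }, p0/z, [...]
  {z0, z1, z2, z3} = form_reg_tuple {z0, z8}:0, {z1, z9}:0, {z2, z10}:0, {z3, z11}:0

Here each 2-vector load contributes its first vector (z0..z3) to a single
4-register tuple, so the hints must keep those first vectors consecutive.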

@llvmbot (Member) commented Jan 15, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions
which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all
strided registers from this class to the list of hints.
This patch changes getRegAllocationHints to restrict this list:

  • If the pseudo uses the ZPRMul class, the first load must begin with a register
    whose number is a multiple of 2 or 4.
  • Only add a strided register to the list if it does not already have any
    live intervals.

Patch is 64.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123081.diff

4 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp (+70-9)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll (+152-128)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-vdot.ll (+136-118)
  • (added) llvm/test/CodeGen/AArch64/sme2-multivec-regalloc.mir (+183)
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 5973b63b5a8024..aac1dc9cb5c062 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -20,6 +20,7 @@
 #include "MCTargetDesc/AArch64InstPrinter.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/BinaryFormat/Dwarf.h"
+#include "llvm/CodeGen/LiveRegMatrix.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -1107,23 +1108,83 @@ bool AArch64RegisterInfo::getRegAllocationHints(
   // FORM_TRANSPOSED_REG_TUPLE pseudo, we want to favour reducing copy
   // instructions over reducing the number of clobbered callee-save registers,
   // so we add the strided registers as a hint.
+  const MachineInstr *TupleInst = nullptr;
   unsigned RegID = MRI.getRegClass(VirtReg)->getID();
   // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
   if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
        RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
-      any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
-        return Use.getOpcode() ==
-                   AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
-               Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+      any_of(MRI.use_nodbg_instructions(VirtReg), [&TupleInst](
+                                                      const MachineInstr &Use) {
+        bool IsTuple =
+            Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
+            Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+        TupleInst = &Use;
+        return IsTuple;
       })) {
-    const TargetRegisterClass *StridedRC =
-        RegID == AArch64::ZPR2StridedOrContiguousRegClassID
-            ? &AArch64::ZPR2StridedRegClass
-            : &AArch64::ZPR4StridedRegClass;
+    unsigned LdOps = TupleInst->getNumOperands() - 1;
+    const TargetRegisterClass *StridedRC = LdOps == 2
+                                               ? &AArch64::ZPR2StridedRegClass
+                                               : &AArch64::ZPR4StridedRegClass;
 
+    SmallVector<MCPhysReg, 4> StridedOrder;
     for (MCPhysReg Reg : Order)
       if (StridedRC->contains(Reg))
-        Hints.push_back(Reg);
+        StridedOrder.push_back(Reg);
+
+    int OpIdx = TupleInst->findRegisterUseOperandIdx(VirtReg, this);
+    if (OpIdx == -1)
+      return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
+                                                       MF, VRM);
+
+    unsigned TupleID =
+        MRI.getRegClass(TupleInst->getOperand(0).getReg())->getID();
+    bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+                    TupleID == AArch64::ZPR4Mul4RegClassID;
+
+    if (OpIdx == 1) {
+      for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+        MCPhysReg Reg = StridedOrder[I];
+        unsigned FirstReg = getSubReg(Reg, AArch64::zsub0);
+
+        // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+        // register of the first load should be a multiple of 2 or 4.
+        if (IsMulZPR &&
+            (getSubReg(Reg, AArch64::zsub0) - AArch64::Z0) % LdOps != 0)
+          continue;
+        // Skip this register if it has any live intervals assigned.
+        if (Matrix->isPhysRegUsed(Reg))
+          continue;
+
+        bool CanAssign = true;
+        for (unsigned Next = 1; Next < LdOps; ++Next) {
+          // Ensure we can assign enough registers from the list for all loads.
+          if (I + Next >= StridedOrder.size()) {
+            CanAssign = false;
+            break;
+          }
+          // Ensure the subsequent registers are not live and that the starting
+          // sub-registers are sequential.
+          MCPhysReg NextReg = StridedOrder[I + Next];
+          if (Matrix->isPhysRegUsed(NextReg) ||
+              (getSubReg(NextReg, AArch64::zsub0) != FirstReg + Next)) {
+            CanAssign = false;
+            break;
+          }
+        }
+        if (CanAssign)
+          Hints.push_back(Reg);
+      }
+    } else if (VRM->hasPhys(TupleInst->getOperand(1).getReg())) {
+      // This is not the first load in the sequence. Find the register
+      // assigned to the first and match to a strided reg in the list.
+      MCPhysReg FirstLoadPhysReg =
+          VRM->getPhys(TupleInst->getOperand(1).getReg());
+      for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+        if (StridedOrder[I] == FirstLoadPhysReg &&
+            (I + (OpIdx - 1) < StridedOrder.size()))
+          Hints.push_back(StridedOrder[I + (OpIdx - 1)]);
+      }
+    }
 
     return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
                                                      VRM);
diff --git a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
index 86ed63d743713c..4abc28c4863421 100644
--- a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
@@ -615,15 +615,17 @@ define void @udot_form_2x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
 ; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
 ; CHECK-NEXT:    ptrue pn8.b
 ; CHECK-NEXT:    mov w8, wzr
-; CHECK-NEXT:    str z9, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z8, [sp, #2, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    ld1b { z0.b, z8.b }, pn8/z, [x0]
-; CHECK-NEXT:    ld1b { z1.b, z9.b }, pn8/z, [x0, x1]
-; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z8.b, z9.b }, z0.b[0]
-; CHECK-NEXT:    ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    str z11, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    str z10, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1b { z2.b, z10.b }, pn8/z, [x0]
+; CHECK-NEXT:    ld1b { z3.b, z11.b }, pn8/z, [x0, x1]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z2.b, z3.b }, z0.b[0]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z10.b, z11.b }, z0.b[0]
+; CHECK-NEXT:    ldr z11, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #2, mul vl] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
 ; CHECK-NEXT:    addvl sp, sp, #3
 ; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
@@ -638,6 +640,7 @@ entry:
   %6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
   tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> poison, i32 0)
+  store <vscale x 16 x i8> %scalable_arg, ptr %ptr
   ret void
 }
 
@@ -699,33 +702,35 @@ define void @udot_form_4x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
 ; CHECK-NEXT:    lsl x9, x1, #1
 ; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
 ; CHECK-NEXT:    ptrue pn8.b
-; CHECK-NEXT:    str z15, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z23, [sp, #1, mul vl] // 16-byte Folded Spill
 ; CHECK-NEXT:    mov w8, wzr
-; CHECK-NEXT:    str z14, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    str z22, [sp, #2, mul vl] // 16-byte Folded Spill
 ; CHECK-NEXT:    add x10, x9, x1
-; CHECK-NEXT:    str z13, [sp, #3, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z12, [sp, #4, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z11, [sp, #5, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z10, [sp, #6, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z9, [sp, #7, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z8, [sp, #8, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    ld1b { z0.b, z4.b, z8.b, z12.b }, pn8/z, [x0]
-; CHECK-NEXT:    ld1b { z1.b, z5.b, z9.b, z13.b }, pn8/z, [x0, x1]
-; CHECK-NEXT:    ld1b { z2.b, z6.b, z10.b, z14.b }, pn8/z, [x0, x9]
-; CHECK-NEXT:    ld1b { z3.b, z7.b, z11.b, z15.b }, pn8/z, [x0, x10]
-; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z4.b - z7.b }, z0.b[0]
-; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z8.b - z11.b }, z0.b[0]
-; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z12.b - z15.b }, z0.b[0]
-; CHECK-NEXT:    ldr z15, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z14, [sp, #2, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z13, [sp, #3, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z12, [sp, #4, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z11, [sp, #5, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z10, [sp, #6, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z9, [sp, #7, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z8, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    str z21, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z20, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z19, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z18, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z17, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z16, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1b { z16.b, z20.b, z24.b, z28.b }, pn8/z, [x0]
+; CHECK-NEXT:    ld1b { z17.b, z21.b, z25.b, z29.b }, pn8/z, [x0, x1]
+; CHECK-NEXT:    ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
+; CHECK-NEXT:    ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
+; CHECK-NEXT:    ldr z23, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z22, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z21, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z20, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z19, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z18, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z17, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z16, [sp, #8, mul vl] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
 ; CHECK-NEXT:    addvl sp, sp, #9
 ; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
@@ -760,6 +765,7 @@ entry:
   tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %8, <vscale x 16 x i8> %13, <vscale x 16 x i8> %18, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %4, <vscale x 16 x i8> %9, <vscale x 16 x i8> %14, <vscale x 16 x i8> %19, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.udot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %5, <vscale x 16 x i8> %10, <vscale x 16 x i8> %15, <vscale x 16 x i8> %20, <vscale x 16 x i8> poison, i32 0)
+  store <vscale x 16 x i8> %scalable_arg, ptr %ptr
   ret void
 }
 
@@ -863,15 +869,17 @@ define void @usdot_form_2x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
 ; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
 ; CHECK-NEXT:    ptrue pn8.b
 ; CHECK-NEXT:    mov w8, wzr
-; CHECK-NEXT:    str z9, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z8, [sp, #2, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    ld1b { z0.b, z8.b }, pn8/z, [x0]
-; CHECK-NEXT:    ld1b { z1.b, z9.b }, pn8/z, [x0, x1]
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx2], { z8.b, z9.b }, z0.b[0]
-; CHECK-NEXT:    ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    str z11, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    str z10, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1b { z2.b, z10.b }, pn8/z, [x0]
+; CHECK-NEXT:    ld1b { z3.b, z11.b }, pn8/z, [x0, x1]
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx2], { z2.b, z3.b }, z0.b[0]
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx2], { z10.b, z11.b }, z0.b[0]
+; CHECK-NEXT:    ldr z11, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #2, mul vl] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
 ; CHECK-NEXT:    addvl sp, sp, #3
 ; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
@@ -886,6 +894,7 @@ entry:
   %6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
   tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> poison, i32 0)
+  store <vscale x 16 x i8> %scalable_arg, ptr %ptr
   ret void
 }
 
@@ -947,33 +956,35 @@ define void @usdot_form_4x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
 ; CHECK-NEXT:    lsl x9, x1, #1
 ; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
 ; CHECK-NEXT:    ptrue pn8.b
-; CHECK-NEXT:    str z15, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z23, [sp, #1, mul vl] // 16-byte Folded Spill
 ; CHECK-NEXT:    mov w8, wzr
-; CHECK-NEXT:    str z14, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    str z22, [sp, #2, mul vl] // 16-byte Folded Spill
 ; CHECK-NEXT:    add x10, x9, x1
-; CHECK-NEXT:    str z13, [sp, #3, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z12, [sp, #4, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z11, [sp, #5, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z10, [sp, #6, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z9, [sp, #7, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z8, [sp, #8, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    ld1b { z0.b, z4.b, z8.b, z12.b }, pn8/z, [x0]
-; CHECK-NEXT:    ld1b { z1.b, z5.b, z9.b, z13.b }, pn8/z, [x0, x1]
-; CHECK-NEXT:    ld1b { z2.b, z6.b, z10.b, z14.b }, pn8/z, [x0, x9]
-; CHECK-NEXT:    ld1b { z3.b, z7.b, z11.b, z15.b }, pn8/z, [x0, x10]
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z0.b - z3.b }, z0.b[0]
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z4.b - z7.b }, z0.b[0]
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z8.b - z11.b }, z0.b[0]
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z12.b - z15.b }, z0.b[0]
-; CHECK-NEXT:    ldr z15, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z14, [sp, #2, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z13, [sp, #3, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z12, [sp, #4, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z11, [sp, #5, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z10, [sp, #6, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z9, [sp, #7, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z8, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    str z21, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z20, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z19, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z18, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z17, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z16, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1b { z16.b, z20.b, z24.b, z28.b }, pn8/z, [x0]
+; CHECK-NEXT:    ld1b { z17.b, z21.b, z25.b, z29.b }, pn8/z, [x0, x1]
+; CHECK-NEXT:    ld1b { z18.b, z22.b, z26.b, z30.b }, pn8/z, [x0, x9]
+; CHECK-NEXT:    ld1b { z19.b, z23.b, z27.b, z31.b }, pn8/z, [x0, x10]
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z16.b - z19.b }, z0.b[0]
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z20.b - z23.b }, z0.b[0]
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z24.b - z27.b }, z0.b[0]
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z28.b - z31.b }, z0.b[0]
+; CHECK-NEXT:    ldr z23, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z22, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z21, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z20, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z19, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z18, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z17, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z16, [sp, #8, mul vl] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
 ; CHECK-NEXT:    addvl sp, sp, #9
 ; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
@@ -1008,6 +1019,7 @@ entry:
   tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %8, <vscale x 16 x i8> %13, <vscale x 16 x i8> %18, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %4, <vscale x 16 x i8> %9, <vscale x 16 x i8> %14, <vscale x 16 x i8> %19, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.usdot.lane.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %5, <vscale x 16 x i8> %10, <vscale x 16 x i8> %15, <vscale x 16 x i8> %20, <vscale x 16 x i8> poison, i32 0)
+  store <vscale x 16 x i8> %scalable_arg, ptr %ptr
   ret void
 }
 
@@ -1113,15 +1125,17 @@ define void @sdot_form_2x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
 ; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
 ; CHECK-NEXT:    ptrue pn8.b
 ; CHECK-NEXT:    mov w8, wzr
-; CHECK-NEXT:    str z9, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    str z8, [sp, #2, mul vl] // 16-byte Folded Spill
-; CHECK-NEXT:    ld1b { z0.b, z8.b }, pn8/z, [x0]
-; CHECK-NEXT:    ld1b { z1.b, z9.b }, pn8/z, [x0, x1]
-; CHECK-NEXT:    sdot za.s[w8, 0, vgx2], { z0.b, z1.b }, z0.b[0]
-; CHECK-NEXT:    sdot za.s[w8, 0, vgx2], { z8.b, z9.b }, z0.b[0]
-; CHECK-NEXT:    ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEXT:    ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    str z11, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    str z10, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1b { z2.b, z10.b }, pn8/z, [x0]
+; CHECK-NEXT:    ld1b { z3.b, z11.b }, pn8/z, [x0, x1]
+; CHECK-NEXT:    sdot za.s[w8, 0, vgx2], { z2.b, z3.b }, z0.b[0]
+; CHECK-NEXT:    sdot za.s[w8, 0, vgx2], { z10.b, z11.b }, z0.b[0]
+; CHECK-NEXT:    ldr z11, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #2, mul vl] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    st1b { z0.b }, p0, [x0]
 ; CHECK-NEXT:    addvl sp, sp, #3
 ; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
@@ -1136,6 +1150,7 @@ entry:
   %6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
   tail call void @llvm.aarch64.sme.sdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> poison, i32 0)
   tail call void @llvm.aarch64.sme.sdot.lane.za32.vg1x2.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> poison, i32 0)
+  store <vscale x 16 x i8> %scalable_arg, ptr %ptr
   ret void
 }
 
@@ -1197,33 +1212,35 @@ define void @sdot_form_4x_tuple_svecc(ptr %ptr, i64 %stride, <vscale x 16 x i8>
 ; CHECK-NEXT:    lsl x9, x1, #1
 ; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
 ; CHECK-NEXT:    ptrue pn8.b...
[truncated]


...
---
name: form_4x_tuple_many_live
Collaborator:

Probably a lot of these fields are not necessary for this test (the LLVM IR above is probably also not required), could you clean this test up a bit?

Collaborator:

Thanks for cleaning these up!

Comment on lines 170 to 177
%26:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub0, %9.zsub0, %15.zsub0, %21.zsub0
$za = UDOT_VG4_M4ZZI_BtoS $za, %27, 0, %26, undef %28:zpr_4b, 0
%29:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub1, %9.zsub1, %15.zsub1, %21.zsub1
$za = UDOT_VG4_M4ZZI_BtoS $za, %27, 0, %29, undef %30:zpr_4b, 0
%31:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub2, %9.zsub2, %15.zsub2, %21.zsub2
%35:ppr_3b = PTRUE_B 31, implicit $vg
$za = UDOT_VG4_M4ZZI_BtoS $za, %27, 0, %31, undef %32:zpr_4b, 0
%33:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %4.zsub3, %9.zsub3, %15.zsub3, %21.zsub3
Collaborator:

Could you do some manual re-scheduling of these instructions to make the test easier to read?

Comment on lines 1183 to 1185
        if (StridedOrder[I] == FirstLoadPhysReg &&
            (I + (OpIdx - 1) < StridedOrder.size()))
          Hints.push_back(StridedOrder[I + (OpIdx - 1)]);
Collaborator:

I don't think you can make the assumption that in the array Order, z0_z4_z8_z12 is followed by z1_z5_z9_z13, because if z1, z5, z9 or z13 is already required for something else, the next register in that list may be z2_z6_z10_z14. I think what you want to do instead is calculate the expected register and then add it to Hints iff it is contained in Order.
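
A minimal sketch of that approach (variable names such as ExpectedStart are illustrative, not taken from the patch): derive the strided register wanted for this operand from the physreg already assigned to the first operand, and only hint it if the allocator actually offers it.

// Sketch only: hint the strided register whose zsub0 sub-register is
// (first operand's starting sub-register + OpIdx - 1), if it is in the list
// of strided registers drawn from Order.
MCRegister FirstLoadPhysReg = VRM->getPhys(TupleInst->getOperand(1).getReg());
unsigned ExpectedStart =
    getSubReg(FirstLoadPhysReg, AArch64::zsub0) + (OpIdx - 1);
for (MCPhysReg Reg : StridedOrder)
  if (getSubReg(Reg, AArch64::zsub0) == ExpectedStart)
    Hints.push_back(Reg);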

Comment on lines 1116 to 1122
      any_of(MRI.use_nodbg_instructions(VirtReg), [&TupleInst](
                                                      const MachineInstr &Use) {
        bool IsTuple =
            Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
            Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
        TupleInst = &Use;
        return IsTuple;
Collaborator:

This any_of causes it to pick any FORM_TRANSPOSED_* use at random and only consider that one to analyse further. If it then doesn't match the criteria for adding a hint, it doesn't continue looking.

Instead, we should iterate through all uses, and only if we find a FORM_TRANSPOSED_* pseudo that matches the criteria should we add the hint and then stop looking.
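
Roughly the control flow being suggested (a sketch, not the committed code):

// Walk every use of VirtReg; only FORM_TRANSPOSED_* pseudos are interesting,
// and we stop as soon as one of them has produced a hint.
for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
  if (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
      Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
    continue;
  // ... analyse this use and possibly append to Hints ...
  if (!Hints.empty())
    break;
}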

Comment on lines 1134 to 1137
    int OpIdx = TupleInst->findRegisterUseOperandIdx(VirtReg, this);
    if (OpIdx == -1)
      return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
                                                       MF, VRM);
Collaborator:

Should this be an assert instead? I'd expect it to never be -1. Otherwise, this conditional exit should be moved above the loop that fills StridedOrder.

Contributor Author:

It should be an assert, yes. We should only ever reach this point if a use of VirtReg was already found in TupleInst.
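
i.e. something along these lines (the assert message is illustrative):

int OpIdx = Use.findRegisterUseOperandIdx(VirtReg, this);
assert(OpIdx != -1 && "Expected VirtReg to be used by the pseudo");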

        if (CanAssign)
          Hints.push_back(Reg);
      }
    } else if (VRM->hasPhys(TupleInst->getOperand(1).getReg())) {
Collaborator:

As mentioned above, this is making an assumption about the order in which the register allocator assigns registers, which I don't think is an assumption you can make. Can you write this case in such a way that it looks for any operand that has a physreg assigned, and works backward from there?

Comment on lines 1158 to 1175
        bool CanAssign = true;
        for (unsigned Next = 1; Next < LdOps; ++Next) {
          // Ensure we can assign enough registers from the list for all loads.
          if (I + Next >= StridedOrder.size()) {
            CanAssign = false;
            break;
          }
          // Ensure the subsequent registers are not live and that the starting
          // sub-registers are sequential.
          MCPhysReg NextReg = StridedOrder[I + Next];
          if (Matrix->isPhysRegUsed(NextReg) ||
              (getSubReg(NextReg, AArch64::zsub0) != FirstReg + Next)) {
            CanAssign = false;
            break;
          }
        }
        if (CanAssign)
          Hints.push_back(Reg);
Collaborator:

This loop can be written more compactly by doing something along the lines of:

for (unsigned Next = 1; Next < LdOps; ++Next) {
  if (!is_contained(StridedOrder, Reg + Next))
    // cannot assign
}

Contributor Author:

I think checking is_contained for Reg + Next still makes assumptions about the order of the StridedOrder list. Instead, I've rewritten this to find a strided register which matches FirstReg + Next, discarding the current Reg if none is found.

    bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
                    TupleID == AArch64::ZPR4Mul4RegClassID;

    if (OpIdx == 1) {
Collaborator:

This is making an assumption about the order in which the register allocator visits the operands and allocates registers. I think that instead you want to distinguish between the cases where:

  • None of the operands for Use are allocated yet => then check if for each of the operands, we have consecutive tuple registers available, and if so then add the hint for the given OpIdx.
  • Any of the operands for Use is already allocated => find the corresponding tuple reg for the current operand index, if available.

Contributor Author:

I've rewritten this as suggested so that we now look for hints where all consecutive tuple registers are available if none of the operands have been allocated yet.
There is a new test in sme2-multivec-regalloc.mir to check that we still suggest the correct hints if the first load instruction to be allocated is not the first operand.

Comment on lines 1144 to 1151
unsigned AssignedOp = 0;
if (!any_of(make_range(Use.operands_begin() + 1, Use.operands_end()),
            [&](const MachineOperand &Op) {
              if (!VRM->hasPhys(Op.getReg()))
                return false;
              AssignedOp = Op.getOperandNo();
              return true;
            })) {
Collaborator:

This looks a little odd. If it's doing a search for an operand that has its register allocated, then I'd suggest using llvm::find_if to find the operand:

const MachineOperand *AssignedRegOp = llvm::find_if(
    make_range(Use.operands_begin() + 1, Use.operands_end()),
    [&VRM](const MachineOperand &Op) { return VRM->hasPhys(Op.getReg()); });

and then:

if (AssignedRegOp != Use.operands_end()) {
  ..
} else {
  ..
}

Contributor Author:

Thanks, this is easier to follow!

Comment on lines 1162 to 1164
// Skip this register if it has any live intervals assigned.
if (Matrix->isPhysRegUsed(Reg))
  continue;
Collaborator:

I guess this is redundant, because it would not have been suggested as a free register in Order if it was already allocated for something else.

Comment on lines 1169 to 1181
MCPhysReg RegToAssign = Reg;
for (unsigned Next = 1; Next < LdOps; ++Next) {
  MCPhysReg Strided = GetRegStartingAt(FirstReg + Next);
  if (Strided == AArch64::NoRegister ||
      Matrix->isPhysRegUsed(Strided)) {
    RegToAssign = AArch64::NoRegister;
    break;
  }
  if (Next == (unsigned)OpIdx - 1)
    RegToAssign = Strided;
}
if (RegToAssign != AArch64::NoRegister)
  Hints.push_back(RegToAssign);
Collaborator:

This code is a little convoluted. I think you could also avoid the extra nested loop in GetRegStartingAt, by doing the following:

If Reg is e.g. Z1_Z5_Z9_Z13, then loop from Z0_Z4_Z8_Z12..Z3_Z7_Z11_Z15 and check if any of them has a physreg allocated. If not, then you can add Z1_Z5_Z9_Z13.

SmallVector<MCPhysReg> Regs;
unsigned FirstReg = Reg - OpIdx + 1;
for (unsigned I = 0; I < LdOps; ++I)
  Regs.push_back(FirstReg + I);

if (all_of(Regs,
           [&](MCPhysReg R) { return !Matrix->isPhysRegUsed(R); }))
  Hints.push_back(FirstReg + OpIdx - 1);

Contributor Author:

I've rewritten this as suggested; however, I included !is_contained(StridedOrder, FirstReg + I) in the first loop because I think it could be possible for this register to exist outside of the list. It's for this reason that I've also left in the filtering by start register here, as I need StridedOrder to contain every possible strided register for the is_contained.

I've also added a check to make sure that the starting registers are consecutive in the loop and created a test where this is necessary in sme2-multivec-regalloc.mir.

RegID == AArch64::ZPR2StridedOrContiguousRegClassID
? &AArch64::ZPR2StridedRegClass
: &AArch64::ZPR4StridedRegClass;
for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
Collaborator:

The code here seems to be looping through uses of VirtReg, but it does that regardless of its register class or feature flags.

What about returning early from the function if the function has no SME or is not in streaming mode?
And also moving this loop inside an if (RegID == .. || RegID == ...) condition?
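
A possible shape for that early exit (a sketch; hasSME()/isStreaming() are assumed to be the right subtarget predicates to check):

const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
if (!STI.hasSME() || !STI.isStreaming())
  return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
                                                   VRM);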


// If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
// register of the first load should be a multiple of 2 or 4.
if (IsMulZPR && (FirstReg - AArch64::Z0) % LdOps != 0)
Collaborator:

Could you move this filtering to where it populates the StridedOrder vector?
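
A sketch of what moving that filter could look like (illustrative only; as written it applies the multiple-of-LdOps check that is correct for the first operand of the pseudo, so other operands would need the remainder adjusted, and IsMulZPR would have to be computed before this loop):

SmallVector<MCPhysReg, 4> StridedOrder;
for (MCPhysReg Reg : Order)
  if (StridedRC->contains(Reg) &&
      (!IsMulZPR ||
       (getSubReg(Reg, AArch64::zsub0) - AArch64::Z0) % LdOps == 0))
    StridedOrder.push_back(Reg);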



Comment on lines 7 to 11
liveins:
  - { reg: '$x0', virtual-reg: '%0' }
  - { reg: '$x1', virtual-reg: '%1' }
  - { reg: '$z0', virtual-reg: '%2' }
  - { reg: '$z17', virtual-reg: '%3' }
Collaborator:

nit: I think you can remove these as well, because this information will otherwise be recomputed.

…oads.

getRegAllocationHints looks for ZPR2StridedOrContiguous load instructions
which are used by FORM_TRANSPOSED_REG_TUPLE pseudos and adds all
strided registers from this class to the list of hints.
This patch changes getRegAllocationHints to restrict this list:
 - If the pseudo uses ZPRMul class, the first load must begin with a register
   which is a multiple of 2 or 4.
 - Only add a strided register to the list if it does not already have any
   live intervals.
…TRANSPOSED

  pseudo, find hints based on whether any operands have been allocated yet.
- Remove any_of which finds FORM_TRANSPOSED uses of VirtReg & instead iterate
  through all uses.
- Clean up sme2-multivec-regalloc.mir & add new test which changes the
  allocation order of the load instructions.
- Add early exit if the function does not have SME or is not in streaming mode
- Move loop over uses of VirtReg into if block checking the RegID
- Remove GetRegStartingAt
Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
continue;

unsigned LdOps = Use.getNumOperands() - 1;
Collaborator:

I just realised that LdOps is a misnomer, because it is assigned the number of operands of the use (form_reg_tuple). If the use here has 4 operands, it could still be that the load has 2, e.g.

ld1 { z0, z8 }, p0/z, [...]
ld1 { z1, z9 }, p0/z, [...]
ld1 { z2, z10 }, p0/z, [...]
ld1 { z3, z11 }, p0/z, [...]
{z0, z1, z2, z3} = form_reg_tuple {z0, z8}:0, {z1, z9}:0, {z2, z10}:0, {z3, z11}:0

The uses below assume this is about the number of operands of the Use, so it seems like it's just the name that's wrong.

The same is not true for StridedRC, which uses the wrong register class (it should decide the strided RC based on RegID instead).

Contributor Author:

In the last commit I did change StridedRC to be based on RegID; this was to support the case where the sizes of the multi-vectors used by the loads and the pseudo are different as you've described here. I've now also renamed LdOps to UseOps.
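
For reference, basing StridedRC on RegID is the form already shown in the diff above, i.e. roughly:

const TargetRegisterClass *StridedRC =
    RegID == AArch64::ZPR2StridedOrContiguousRegClassID
        ? &AArch64::ZPR2StridedRegClass
        : &AArch64::ZPR4StridedRegClass;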

for (unsigned I = 0; I < StridedOrder.size(); ++I) {
  MCPhysReg Reg = StridedOrder[I];
  SmallVector<MCPhysReg> Regs;
  unsigned FirstStridedReg = Reg - OpIdx + 1;
Collaborator:

I would avoid doing this, because Reg - OpIdx + 1 may not be an SVE tuple register, which means that getSubReg(FirstStridedReg, AArch64::zsub0) might fail.

For example, if the first tuple register in the list is Z0_Z1 and we're looking at the second operand of the form_*tuple pseudo, i.e. OpIdx = 2, then FirstStridedReg would be X26_X27.

You can instead write this as:

unsigned SubRegIdx = Use.getOperand(OpIdx).getSubReg();
if (IsMulZPR && (getSubReg(Reg, SubRegIdx) - AArch64::Z0) % UseOps !=
                    ((unsigned)OpIdx - 1))
  continue;

Contributor Author:

Thanks, I added the is_contained later to make sure getSubReg would not fail but missed that it could also fail here.

Comment on lines 1163 to 1173
for (unsigned Op = 0; Op < LdOps; ++Op) {
  if (!is_contained(StridedOrder, FirstStridedReg + Op) ||
      getSubReg(FirstStridedReg + Op, AArch64::zsub0) !=
          FirstSubReg + Op)
    break;
  Regs.push_back(FirstStridedReg + Op);
}

if (Regs.size() == LdOps && all_of(Regs, [&](MCPhysReg R) {
      return !Matrix->isPhysRegUsed(R);
    }))
Collaborator:

nit: This could be rewritten in such a way that it doesn't need SmallVector<MCPhysReg> Regs as an intermediate step, e.g.

auto IsFreeConsecutiveRegs = [&](unsigned I) {
  // conditions
};
if (all_of(iota_range<unsigned>(0U, UseOps, /*Inclusive=*/false),
           IsFreeConsecutiveRegs))
  Hints.push_back(Reg);

return VRM->hasPhys(Op.getReg());
});

if (AssignedRegOp == Use.operands_end()) {
Collaborator:

I think an example would be useful here.
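
For instance, a comment along these lines could work (illustrative wording, not the committed text):

// e.g. %t:zpr4mul4 = FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO %a.zsub0, %b.zsub0, %c.zsub0, %d.zsub0
// If none of %a..%d has a physreg yet, look for a group of free strided
// registers whose zsub0 sub-registers are consecutive (and, for the Mul
// classes, start at a multiple of the tuple size), e.g.
//   Z0_Z4_Z8_Z12, Z1_Z5_Z9_Z13, Z2_Z6_Z10_Z14, Z3_Z7_Z11_Z15
// and hint the one that corresponds to this operand's position.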

- Rewrote the case where no registers are already assigned to avoid creating
  the Regs vector.
- Added more comments with an example
@sdesmalen-arm (Collaborator) left a comment:

LGTM with nit addressed.

Would it make sense to update the title/commit message to reflect that you've also made it less strict? i.e. extracting one column from 4 x 2-vector loads can be used with an instruction that requires sequential regs now.

@@ -1099,6 +1100,11 @@ bool AArch64RegisterInfo::getRegAllocationHints(
const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
const MachineRegisterInfo &MRI = MF.getRegInfo();
Collaborator:

nit: this can be moved down (after the if(..) condition)

@kmclaughlin-arm kmclaughlin-arm changed the title [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads [AArch64][SME] Make getRegAllocationHints more specific for multi-vector loads Jan 30, 2025
@kmclaughlin-arm (Contributor Author):

> Would it make sense to update the title/commit message to reflect that you've also made it less strict? i.e. extracting one column from 4 x 2-vector loads can be used with an instruction that requires sequential regs now.

Thanks for approving this! I've reworded the title a bit and changed the commit message to include allowing x2 & x4 multivector loads and intrinsics.

@kmclaughlin-arm kmclaughlin-arm merged commit 2fbfaff into llvm:main Jan 30, 2025
5 of 8 checks passed
@kmclaughlin-arm kmclaughlin-arm deleted the 2-strided-ld-regalloc branch February 6, 2025 13:48