Skip to content

[AArch64][SME2] Improve register allocation of multi-vector SME intrinsics #116399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 12, 2024

Conversation

kmclaughlin-arm
Copy link
Contributor

@kmclaughlin-arm kmclaughlin-arm commented Nov 15, 2024

The FORM_TRANSPOSED_REG_TUPLE pseudos have been created to
improve register allocation for intrinsics which use strided and
contiguous multi-vector registers, avoiding unnecessary copies.

If the operands of the pseudo are copies where the source register is in
the StridedOrContiguous class, the pseudo is used by getRegAllocationHints
to suggest a contigious multi-vector register which matches the subregister
sequence used by the operands.
If the operands do not match this pattern, the pseudos are expanded
to a REG_SEQUENCE.

Patch contains changes by Matthew Devereau.

This patch adds a pseudo node to help towards improving register
allocation of multi-vector SME intrinsics.

The FORM_STRIDED_TUPLE node is emitted if each of the operands of a
contiguous multi-vector dot intrinsic are the result of a strided
multi-vector load. The operands of the psuedo will be one subregister
at the same index from each of these strided loads.

Follow up patches will use this pseudo when adding register allocation
hints to remove unecessary register copies in this scenario. Subregister
liveness is also required to achieve this and has been enabled in the
tests changed by this patch.

Patch contains changes by Matthew Devereau.
@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

This patch adds a pseudo node to help towards improving register
allocation of multi-vector SME intrinsics.

The FORM_STRIDED_TUPLE node is emitted if each of the operands of a
contiguous multi-vector dot intrinsic are the result of a strided
multi-vector load. The operands of the pseudo will be one subregister
at the same index from each of these strided loads.

Follow up patches will use this pseudo when adding register allocation
hints to remove unnecessary register copies in this scenario. Subregister
liveness is also required to achieve this and has been enabled in the
tests changed by this patch.

Patch contains changes by Matthew Devereau.


Patch is 99.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116399.diff

8 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+32)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+27)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+63)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+11)
  • (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+12)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll (+447-125)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-vdot.ll (+322-34)
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 055cb3cefcedf9..dabcaaf9f5c874 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -67,6 +67,10 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
                             TargetRegisterClass ContiguousClass,
                             TargetRegisterClass StridedClass,
                             unsigned ContiguousOpc, unsigned StridedOpc);
+  bool expandFormTuplePseudo(MachineBasicBlock &MBB,
+                             MachineBasicBlock::iterator MBBI,
+                             MachineBasicBlock::iterator &NextMBBI,
+                             unsigned Size);
   bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                     unsigned BitSize);
 
@@ -1142,6 +1146,30 @@ bool AArch64ExpandPseudo::expandMultiVecPseudo(
   return true;
 }
 
+bool AArch64ExpandPseudo::expandFormTuplePseudo(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+    MachineBasicBlock::iterator &NextMBBI, unsigned Size) {
+  assert(Size == 2 || Size == 4 && "Invalid Tuple Size");
+  MachineInstr &MI = *MBBI;
+  Register ReturnTuple = MI.getOperand(0).getReg();
+
+  const TargetRegisterInfo *TRI =
+      MBB.getParent()->getSubtarget().getRegisterInfo();
+  for (unsigned i = 0; i < Size; i++) {
+    Register FormTupleOpReg = MI.getOperand(i + 1).getReg();
+    Register ReturnTupleSubReg =
+        TRI->getSubReg(ReturnTuple, AArch64::zsub0 + i);
+    if (FormTupleOpReg != ReturnTupleSubReg)
+      BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORR_ZZZ))
+          .addReg(ReturnTupleSubReg, RegState::Define)
+          .addReg(FormTupleOpReg)
+          .addReg(FormTupleOpReg);
+  }
+
+  MI.eraseFromParent();
+  return true;
+}
+
 /// If MBBI references a pseudo instruction that should be expanded here,
 /// do the expansion and return true.  Otherwise return false.
 bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
@@ -1724,6 +1752,10 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
      return expandMultiVecPseudo(
          MBB, MBBI, AArch64::ZPR4RegClass, AArch64::ZPR4StridedRegClass,
          AArch64::LDNT1D_4Z, AArch64::LDNT1D_4Z_STRIDED);
+   case AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO:
+     return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 2);
+   case AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO:
+     return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 4);
   }
   return false;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 1969c830f4d312..d46bae07b3d4c5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -504,6 +504,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
 
   bool SelectAllActivePredicate(SDValue N);
   bool SelectAnyPredicate(SDValue N);
+
+  void SelectFormTuplePseudo(SDNode *N, unsigned Size);
 };
 
 class AArch64DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -7181,6 +7183,14 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
     }
     break;
   }
+  case AArch64ISD::FORM_STRIDED_TUPLE_X2: {
+    SelectFormTuplePseudo(Node, 2);
+    return;
+  }
+  case AArch64ISD::FORM_STRIDED_TUPLE_X4: {
+    SelectFormTuplePseudo(Node, 4);
+    return;
+  }
   }
 
   // Select the default instruction
@@ -7438,3 +7448,20 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
   Offset = CurDAG->getTargetConstant(0, SDLoc(N), MVT::i64);
   return true;
 }
+
+void AArch64DAGToDAGISel::SelectFormTuplePseudo(SDNode *Node, unsigned Size) {
+  assert((Size == 2 || Size == 4) && "Invalid Tuple size");
+  EVT VT = Node->getValueType(0);
+  SmallVector<SDValue> Ops;
+  for (unsigned I = 0; I < Size; I++)
+    Ops.push_back(Node->getOperand(I));
+  SDLoc DL(Node);
+  unsigned Opc = Size == 2 ? AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO
+                           : AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO;
+  SDNode *Tuple = CurDAG->getMachineNode(Opc, DL, MVT::Untyped, Ops);
+  SDValue SuperReg = SDValue(Tuple, 0);
+  for (unsigned I = 0; I < Size; ++I)
+    ReplaceUses(SDValue(Node, I), CurDAG->getTargetExtractSubreg(
+                                      AArch64::zsub0 + I, DL, VT, SuperReg));
+  CurDAG->RemoveDeadNode(Node);
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d1c3d4eddc880..b8c87b0ec2ea5f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2808,6 +2808,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FMUL_PRED)
     MAKE_CASE(AArch64ISD::FSUB_PRED)
     MAKE_CASE(AArch64ISD::RDSVL)
+    MAKE_CASE(AArch64ISD::FORM_STRIDED_TUPLE_X2)
+    MAKE_CASE(AArch64ISD::FORM_STRIDED_TUPLE_X4)
     MAKE_CASE(AArch64ISD::BIC)
     MAKE_CASE(AArch64ISD::CBZ)
     MAKE_CASE(AArch64ISD::CBNZ)
@@ -5709,6 +5711,46 @@ SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
                      Mask);
 }
 
+static unsigned getIntrinsicID(const SDNode *N);
+
+SDValue TryLowerMultiVecSMEDotIntrinsic(SDValue Op, SelectionDAG &DAG,
+                                        unsigned Size) {
+  assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
+  auto IsStridedLoad = [Size](SDValue Op) -> bool {
+    unsigned Intrinsic = getIntrinsicID(Op.getNode());
+    if (Size == 2)
+      return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x2;
+    else
+      return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x4;
+  };
+
+  SmallVector<SDValue> Ops;
+  unsigned LastLoadIdx = Size == 2 ? 5 : 7;
+  unsigned LoadResNo = Op.getOperand(3).getResNo();
+  for (unsigned I = 3; I < LastLoadIdx; I++) {
+    if (!IsStridedLoad(Op->getOperand(I)) ||
+        Op.getOperand(I).getResNo() != LoadResNo)
+      return SDValue();
+    Ops.push_back(Op->getOperand(I));
+  }
+
+  EVT VT = Op->getOperand(3).getValueType();
+  SDVTList VTList =
+      Size == 2 ? DAG.getVTList(VT, VT) : DAG.getVTList(VT, VT, VT, VT);
+  unsigned Opc = Size == 2 ? AArch64ISD::FORM_STRIDED_TUPLE_X2
+                           : AArch64ISD::FORM_STRIDED_TUPLE_X4;
+  SDLoc DL(Op);
+  SDValue Pseudo = DAG.getNode(Opc, DL, VTList, Ops);
+
+  SmallVector<SDValue> DotOps = {Op.getOperand(0), Op->getOperand(1),
+                                 Op->getOperand(2)};
+  for (unsigned I = 0; I < Size; I++)
+    DotOps.push_back(Pseudo.getValue(I));
+  DotOps.push_back(Op->getOperand(DotOps.size()));
+  DotOps.push_back(Op->getOperand(DotOps.size()));
+  return DAG.getNode(Op->getOpcode(), DL, MVT::Other, DotOps);
+}
+
 // Lower an SME LDR/STR ZA intrinsic
 // Case 1: If the vector number (vecnum) is an immediate in range, it gets
 // folded into the instruction
@@ -5898,6 +5940,22 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
         Op->getOperand(0), // Chain
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
+  case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_suvdot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_usvdot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_udot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x4:
+  case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x4:
+    return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 4);
+  case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x2:
+  case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x2:
+  case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x2:
+  case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x2:
+  case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x2:
+  case Intrinsic::aarch64_sme_udot_lane_za32_vg1x2:
+    return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 2);
   }
 }
 
@@ -7639,6 +7697,11 @@ static unsigned getIntrinsicID(const SDNode *N) {
       return IID;
     return Intrinsic::not_intrinsic;
   }
+  case ISD::INTRINSIC_W_CHAIN: {
+    unsigned IID = N->getConstantOperandVal(1);
+    if (IID < Intrinsic::num_intrinsics)
+      return IID;
+  }
   }
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d11da64d3f84eb..c7a70ab9f3c898 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -478,6 +478,9 @@ enum NodeType : unsigned {
   SME_ZA_LDR,
   SME_ZA_STR,
 
+  FORM_STRIDED_TUPLE_X2,
+  FORM_STRIDED_TUPLE_X4,
+
   // NEON Load/Store with post-increment base updates
   LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE,
   LD3post,
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index a6ba6ddc30b277..5fb44fe5146d3c 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -28,6 +28,17 @@ def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
 def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
                              [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
+
+def SDT_FORM_STRIDED_TUPLE_X2 : SDTypeProfile<4, 4,
+                             [SDTCisVec<0>, SDTCisSameAs<0, 1>,
+                              SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;
+
+def SDT_FORM_STRIDED_TUPLE_X4 : SDTypeProfile<4, 4,
+                             [SDTCisVec<0>, SDTCisSameAs<0, 1>,
+                              SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>,
+                              SDTCisSameAs<0, 4>, SDTCisSameAs<0, 5>,
+                              SDTCisSameAs<0, 6>, SDTCisSameAs<0, 7>]>;
+
 def AArch64CoalescerBarrier
     : SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;
 
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 8c256b5818ee88..41508bce651c6b 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -34,6 +34,18 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0,  4>", []>;
 
 def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;
 
+def FORM_STRIDED_TUPLE_X2_PSEUDO :
+  Pseudo<(outs ZPR2Mul2:$tup),
+         (ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]>{
+  let hasSideEffects = 0;
+}
+
+def FORM_STRIDED_TUPLE_X4_PSEUDO :
+  Pseudo<(outs ZPR4Mul4:$tup),
+         (ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]>{
+  let hasSideEffects = 0;
+}
+
 def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
 def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
                              [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
diff --git a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
index 1e835c92ba9e4c..eddff238ace031 100644
--- a/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -force-streaming -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -force-streaming -enable-subreg-liveness -verify-machineinstrs < %s | FileCheck %s
 
 target triple="aarch64-linux-gnu"
 
@@ -26,18 +26,18 @@ define void @udot_multi_za32_u16_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <
 define void @udot_multi_za32_u16_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3,
 ; CHECK-LABEL: udot_multi_za32_u16_vg1x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z26.d, z7.d
-; CHECK-NEXT:    mov z31.d, z4.d
-; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z26.d, z7.d
 ; CHECK-NEXT:    mov z25.d, z6.d
-; CHECK-NEXT:    mov z30.d, z3.d
+; CHECK-NEXT:    mov z7.d, z4.d
+; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    mov z24.d, z5.d
-; CHECK-NEXT:    mov z29.d, z2.d
 ; CHECK-NEXT:    ld1h { z27.h }, p0/z, [x1]
-; CHECK-NEXT:    mov z28.d, z1.d
-; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z28.h - z31.h }, { z24.h - z27.h }
-; CHECK-NEXT:    udot za.s[w8, 7, vgx4], { z28.h - z31.h }, { z24.h - z27.h }
+; CHECK-NEXT:    mov z6.d, z3.d
+; CHECK-NEXT:    mov z5.d, z2.d
+; CHECK-NEXT:    mov z4.d, z1.d
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z4.h - z7.h }, { z24.h - z27.h }
+; CHECK-NEXT:    udot za.s[w8, 7, vgx4], { z4.h - z7.h }, { z24.h - z27.h }
 ; CHECK-NEXT:    ret
                                         <vscale x 8 x i16> %zn4, <vscale x 8 x i16> %zn5, <vscale x 8 x i16> %zn6, <vscale x 8 x i16> %zn7) #0 {
   call void @llvm.aarch64.sme.udot.za32.vg1x4.nxv8i16(i32 %slice, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3,
@@ -68,18 +68,18 @@ define void @udot_multi_za32_u8_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <v
 define void @udot_multi_za32_u8_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
 ; CHECK-LABEL: udot_multi_za32_u8_vg1x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z26.d, z7.d
-; CHECK-NEXT:    mov z31.d, z4.d
-; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    mov z26.d, z7.d
 ; CHECK-NEXT:    mov z25.d, z6.d
-; CHECK-NEXT:    mov z30.d, z3.d
+; CHECK-NEXT:    mov z7.d, z4.d
+; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    mov z24.d, z5.d
-; CHECK-NEXT:    mov z29.d, z2.d
 ; CHECK-NEXT:    ld1b { z27.b }, p0/z, [x1]
-; CHECK-NEXT:    mov z28.d, z1.d
-; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z28.b - z31.b }, { z24.b - z27.b }
-; CHECK-NEXT:    udot za.s[w8, 7, vgx4], { z28.b - z31.b }, { z24.b - z27.b }
+; CHECK-NEXT:    mov z6.d, z3.d
+; CHECK-NEXT:    mov z5.d, z2.d
+; CHECK-NEXT:    mov z4.d, z1.d
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z4.b - z7.b }, { z24.b - z27.b }
+; CHECK-NEXT:    udot za.s[w8, 7, vgx4], { z4.b - z7.b }, { z24.b - z27.b }
 ; CHECK-NEXT:    ret
                                       <vscale x 16 x i8> %zn4, <vscale x 16 x i8> %zn5, <vscale x 16 x i8> %zn6, <vscale x 16 x i8> %zn7) #0 {
   call void @llvm.aarch64.sme.udot.za32.vg1x4.nxv16i8(i32 %slice, <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
@@ -110,18 +110,18 @@ define void @udot_multi_za64_u16_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <
 define void @udot_multi_za64_u16_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3,
 ; CHECK-LABEL: udot_multi_za64_u16_vg1x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z26.d, z7.d
-; CHECK-NEXT:    mov z31.d, z4.d
-; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z26.d, z7.d
 ; CHECK-NEXT:    mov z25.d, z6.d
-; CHECK-NEXT:    mov z30.d, z3.d
+; CHECK-NEXT:    mov z7.d, z4.d
+; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    mov z24.d, z5.d
-; CHECK-NEXT:    mov z29.d, z2.d
 ; CHECK-NEXT:    ld1h { z27.h }, p0/z, [x1]
-; CHECK-NEXT:    mov z28.d, z1.d
-; CHECK-NEXT:    udot za.d[w8, 0, vgx4], { z28.h - z31.h }, { z24.h - z27.h }
-; CHECK-NEXT:    udot za.d[w8, 7, vgx4], { z28.h - z31.h }, { z24.h - z27.h }
+; CHECK-NEXT:    mov z6.d, z3.d
+; CHECK-NEXT:    mov z5.d, z2.d
+; CHECK-NEXT:    mov z4.d, z1.d
+; CHECK-NEXT:    udot za.d[w8, 0, vgx4], { z4.h - z7.h }, { z24.h - z27.h }
+; CHECK-NEXT:    udot za.d[w8, 7, vgx4], { z4.h - z7.h }, { z24.h - z27.h }
 ; CHECK-NEXT:    ret
                                        <vscale x 8 x i16> %zn4, <vscale x 8 x i16> %zn5, <vscale x 8 x i16> %zn6, <vscale x 8 x i16> %zn7) #1 {
   call void @llvm.aarch64.sme.udot.za64.vg1x4.nxv8i16(i32 %slice, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3,
@@ -152,18 +152,18 @@ define void @usdot_multi_za32_u8_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <
 define void @usdot_multi_za32_u8_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
 ; CHECK-LABEL: usdot_multi_za32_u8_vg1x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z26.d, z7.d
-; CHECK-NEXT:    mov z31.d, z4.d
-; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    mov z26.d, z7.d
 ; CHECK-NEXT:    mov z25.d, z6.d
-; CHECK-NEXT:    mov z30.d, z3.d
+; CHECK-NEXT:    mov z7.d, z4.d
+; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    mov z24.d, z5.d
-; CHECK-NEXT:    mov z29.d, z2.d
 ; CHECK-NEXT:    ld1b { z27.b }, p0/z, [x1]
-; CHECK-NEXT:    mov z28.d, z1.d
-; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z28.b - z31.b }, { z24.b - z27.b }
-; CHECK-NEXT:    usdot za.s[w8, 7, vgx4], { z28.b - z31.b }, { z24.b - z27.b }
+; CHECK-NEXT:    mov z6.d, z3.d
+; CHECK-NEXT:    mov z5.d, z2.d
+; CHECK-NEXT:    mov z4.d, z1.d
+; CHECK-NEXT:    usdot za.s[w8, 0, vgx4], { z4.b - z7.b }, { z24.b - z27.b }
+; CHECK-NEXT:    usdot za.s[w8, 7, vgx4], { z4.b - z7.b }, { z24.b - z27.b }
 ; CHECK-NEXT:    ret
                                       <vscale x 16 x i8> %zn4, <vscale x 16 x i8> %zn5, <vscale x 16 x i8> %zn6, <vscale x 16 x i8> %zn7) #0 {
   call void @llvm.aarch64.sme.usdot.za32.vg1x4.nxv16i8(i32 %slice, <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
@@ -197,18 +197,18 @@ define void @sdot_multi_za32_u16_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <
 define void @sdot_multi_za32_u16_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3,
 ; CHECK-LABEL: sdot_multi_za32_u16_vg1x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z26.d, z7.d
-; CHECK-NEXT:    mov z31.d, z4.d
-; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    mov z26.d, z7.d
 ; CHECK-NEXT:    mov z25.d, z6.d
-; CHECK-NEXT:    mov z30.d, z3.d
+; CHECK-NEXT:    mov z7.d, z4.d
+; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    mov z24.d, z5.d
-; CHECK-NEXT:    mov z29.d, z2.d
 ; CHECK-NEXT:    ld1h { z27.h }, p0/z, [x1]
-; CHECK-NEXT:    mov z28.d, z1.d
-; CHECK-NEXT:    sdot za.s[w8, 0, vgx4], { z28.h - z31.h }, { z24.h - z27.h }
-; CHECK-NEXT:    sdot za.s[w8, 7, vgx4], { z28.h - z31.h }, { z24.h - z27.h }
+; CHECK-NEXT:    mov z6.d, z3.d
+; CHECK-NEXT:    mov z5.d, z2.d
+; CHECK-NEXT:    mov z4.d, z1.d
+; CHECK-NEXT:    sdot za.s[w8, 0, vgx4], { z4.h - z7.h }, { z24.h - z27.h }
+; CHECK-NEXT:    sdot za.s[w8, 7, vgx4], { z4.h - z7.h }, { z24.h - z27.h }
 ; CHECK-NEXT:    ret
                                         <vscale x 8 x i16> %zn4, <vscale x 8 x i16> %zn5, <vscale x 8 x i16> %zn6, <vscale x 8 x i16> %zn7) #0 {
   call void @llvm.aarch64.sme.sdot.za32.vg1x4.nxv8i16(i32 %slice, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3,
@@ -239,18 +239,18 @@ define void @sdot_multi_za32_u8_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <v
 define void @sdot_multi_za32_u8_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 16 x i8> %zn0, <vscale x 16 x i8> %zn1, <vscale x 16 x i8> %zn2, <vscale x 16 x i8> %zn3,
 ; CHECK-LABEL: sdot_multi_za32_u8_vg1x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z26.d, z7.d
-; CHECK-NEXT:    mov z31.d, z4.d
-; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    mov z26.d, z7.d
 ; CHECK-NEXT:    mov z25.d, z6.d
-; CHECK-NEXT:    mov z30.d, z3.d
+; CHECK-NEXT:    mov z7.d, z4.d
+; CHECK-NEXT:    mov w8, w0
 ; CHECK-NEXT:    mov z24.d, z5.d
-; CHECK-NEXT:    mov z29.d, z2.d
 ; CHECK-NEXT:    ld1b { z27.b }, p0/z, [x1]
-; CHECK-NEXT:    mov z28.d, z1.d
-; CHECK-NEXT:    sdot za.s[w8, 0, vgx4], { z28.b - z31.b }, { z24.b - z27.b }
-; CHECK-NEXT:    sdot za.s[w8, 7, vgx4], { z28.b - z31.b }, { z24.b - z27.b }
+; CHECK-NEXT:    mov z6.d, z3.d
+; CHECK-NEXT:    mov z5.d, z2.d
+; CHECK-NEXT:    mov z4.d, z1.d
+; CHECK-NEXT:    sdot za.s[w8,...
[truncated]

Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR doesn't make much sense without seeing the accompanying PR that improves the register allocation using these nodes. Could you either fold those changes into this PR, or stack another PR on top of this one?

@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -force-streaming -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -force-streaming -enable-subreg-liveness -verify-machineinstrs < %s | FileCheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you enabling subreg-liveness here?

Comment on lines 5943 to 5958
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_suvdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_usvdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x4:
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 4);
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x2:
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a structurally simpler way to implement this (avoiding the need to do custom isel) is to change the patterns to always use the FORM_STRIDED_TUPLE_X.._PSEUDO instruction instead of REG_SEQUENCE.

For the multi-vector load case that you're trying to improve, the inputs to the tuple are always COPY nodes of the form:

%9:zpr = COPY %7.zsub0:zpr2stridedorcontiguous

There are cases where the RegisterCoalescer can make better decisions when using regular COPY nodes rather than the FORM_STRIDED_TUPLE pseudos. We could choose to handle the FORM_STRIDED_TUPLE pseudo with the hasPostISelHook = 1 where directly post-isel they are transformed into a REG_SEQUENCE node when any of the input values are not COPY nodes where the source register is in a 'stridedorcontiguous' register class. The REG_SEQUENCE node itself is then lowered later by the TwoAddressInstructionPass into individual COPY nodes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion, @sdesmalen-arm. I was able to remove the FORM_STRIDED_TUPLE nodes and instead add hasPostISelHook = 1 to the pseudos, creating a REG_SEQUENCE if the input values are not copies from a StridedOrContiguous source register.
This has been added in a new commit, with the RegAllocHints commit added on top.

…des.

- Changed the tablegen patterns used by the dot intrinsics to always output
  the FORM_STRIDED_TUPLE_X#_PSEUDO nodes.
- Check that the operands to the pseudo are copies from a StridedOrContiguous
  register class in AdjustInstrPostInstrSelection, falling back on creating
  a REG_SEQUENCE node if not.
This patch implements getRegAllocationHints to improve register
allocation for the ZPR2Mul2Reg & ZPR4Mul4Reg classes.
If a FORM_STRIDED_TUPLE is found, getRegAllocationHints will try
to find a contiguous ZPRMulReg beginning with the same subregister
as the first operand of the pseudo.

For example, if the first strided load has been assigned $z16_z20_z24_z28
and the operands of the pseudo are each accessing subregister zsub2, the
correct register to use would be $z24_z25_z26_z27.
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This remaining FORM_STRIDED_TUPLE_X2/4 code looks like it's probably unused:

Comment on lines 1147 to 1166
unsigned SubRegIdx = 0;
MCRegister FirstLoadPhysReg = VRM->getPhys(FirstLoadVirtReg);

// The subreg number is used to access the correct unit of the
// strided register found in the map above.
switch (MI.getOperand(1).getSubReg()) {
case AArch64::zsub0:
break;
case AArch64::zsub1:
SubRegIdx = 1;
break;
case AArch64::zsub2:
SubRegIdx = 2;
break;
case AArch64::zsub3:
SubRegIdx = 3;
break;
default:
continue;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to shorten this to

unsigned SubRegIdx = AArch64::zsub0 - MI.getOperand(1).getSubReg();
if (SubRegIdx > 3)
  continue;

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @MDevereau, I was able to shorten this to something similar:

  SubRegIdx = MI.getOperand(1).getSubReg() - AArch64::zsub0;
  if (SubRegIdx < 0 || SubRegIdx > 3)
    continue;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SubRegIdx is unsigned so SubRegIdx < 0 shouldn't be necessary I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest commit I changed SubRegIdx to a signed int. The reason was in case getSubReg() happens to return a register number which is smaller than AArch64::zsub0.

- Removed switch from getRegAllocationHints
Operand imm_ty, ComplexPattern tileslice>
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm, (i32 imm_ty:$i)),
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
(FORM_STRIDED_TUPLE_X2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than creating new patterns, can we just update the existing ones? Then maybe other instructions (that use the same pattern class) could also benefit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a number of other intrinsics which use these patterns other than sdot/udot, which is why I initially added a new pattern. However, since we will fall back on REG_SEQUENCE anyway if the expected copy instructions are not found I think we can just update the existing one.

@@ -28,6 +28,7 @@ def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove newline

Comment on lines 8695 to 8697
if (!MO.isReg())
continue;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be an assert? (the operands can only be regs, right?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is right, I've changed this to an assert.


MachineOperand *Def = MRI.getOneDef(MO.getReg());
if (!Def || !Def->isReg() || !Def->getParent()->isCopy())
continue;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than continue it's better to break instead, because if one of them is not a COPY then we don't need to process the other operands.

Comment on lines 8704 to 8705
if (!CpyOp.isReg())
continue;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is always the case for COPY? If so, this check can be removed.

Comment on lines 8716 to 8717
if (MRI.getRegClass(Ld->getReg()) == RegClass)
UseFormStrided = true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would help to have this loop in a separate function that returns false if there is any reason that the operands don't meet the criteria. Then you can also include the check for the subreg indices in the same loop (for which you currently require a std::equal).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved this to a new function called shouldUseFormStridedPseudo() & removed the std::equal test.

…sses

- Removed isReg() check for operand 1 of a copy
- Moved loop over FORM_STRIDED_TUPLE operands to new function.
Comment on lines 7642 to 7646
case ISD::INTRINSIC_W_CHAIN: {
unsigned IID = N->getConstantOperandVal(1);
if (IID < Intrinsic::num_intrinsics)
return IID;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be removed?

continue;

// Look up the physical register mapped to the first load of the pseudo.
Register FirstLoadVirtReg = MI.getOperand(1).getReg();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't name variables with the assumption that these result from certain operations, like Loads.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed this to FirstOpVirtReg.

Comment on lines 1152 to 1154
SubRegIdx = MI.getOperand(1).getSubReg() - AArch64::zsub0;
if (SubRegIdx < 0 || SubRegIdx > 3)
continue;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a switch statement, such that explicitly only zsub0..zsub3 are supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the changes suggested below to use getSubReg instead of iterating through MCRegUnits,SubRegIdx is no longer required. However, I've added the switch statement back in anyway to make sure only zsub0-zsub3 are supported.

Comment on lines 1133 to 1136
if (RegID != AArch64::ZPR2Mul2RegClassID &&
RegID != AArch64::ZPR4Mul4RegClassID)
return DefaultHints;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this condition can be removed (it would e.g. be equally valid for other register class with contiguous tuples, or perhaps just any register class?).

SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
const TargetRegisterInfo *TRI = STI.getRegisterInfo();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be removed because AArch64RegisterInfo this == TRI.

Comment on lines 1129 to 1130
bool DefaultHints =
TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF, VRM);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should prioritise the tuples as added below over any generic hints. That means this function should be called last.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the TargetRegisterInfo::getRegAllocationHints call to the end of this function.

- Removed INTRINSIC_W_CHAIN from getIntrinsicID
- Renamed FORM_STRIDED_TUPLE_X#_PSEUDO -> FORM_TRANSPOSED_REG_TUPLE_X#_PSEUDO
- Add switch statement back into getRegAllocationHints
- Use getSubReg in getRegAllocationHints and remove index into RegUnits
Comment on lines 8720 to 8724
if (MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO) {
// If input values to the FORM_TRANSPOSED_REG_TUPLE pseudo aren't copies
// from a StridedOrContiguous class, fall back on REG_SEQUENCE node.
if (!shouldUseFormStridedPseudo(MI)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Maybe fold those conditions into one? or otherwise bail out early if shouldUseFormStridedPseudo(MI) == true, to avoid a level of indentation?

MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
continue;

switch (MI.getOperand(1).getSubReg()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please make a variable for MI.getOperand(1).getsubReg(), because it's used more than once.

for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
if (MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
continue;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I wonder if we just want to bail out early at this point?

@@ -34,6 +34,20 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0, 4>", []>;

def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;

def FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO :
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs a description of why we add these pseudos, and a comment that we expand them to REG_SEQUENCE with the post-isel hook if they don't meet certain criteria.

- Moved switch in shouldUseFormStridedPseudo outside of loop
- Added a description of the pseudo nodes to SMEInstrFormats.td
…_TRANSPOSED

  pseudo or the subreg indices are not in the range zsub0-zsub3.
Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor nits, but otherwise LGTM.

Could you also please update the PR's title and commit message, to reflect the new name before committing?


const TargetRegisterInfo *TRI =
MBB.getParent()->getSubtarget().getRegisterInfo();
for (unsigned i = 0; i < Size; i++) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
for (unsigned i = 0; i < Size; i++) {
for (unsigned I = 0; I < Size; ++I) {

Comment on lines 8686 to 8690
if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg)
return false;

if (MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg)
return false;
if (MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
return false;
if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
return false;


for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
MIB.add(MI.getOperand(I));
MIB.addImm(SubRegs[I - 1]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
MIB.addImm(SubRegs[I - 1]);
MIB.addImm(AArch64::zsub0 + (I-1));

Then you can remove SubRegs[].

@kmclaughlin-arm kmclaughlin-arm changed the title [AArch64][SME2] Add FORM_STRIDED_TUPLE pseudo nodes [AArch64][SME2] Improve register allocation of multi-vector SME intrinsics Dec 11, 2024
@kmclaughlin-arm kmclaughlin-arm merged commit 5ca26d7 into llvm:main Dec 12, 2024
8 checks passed
@kmclaughlin-arm kmclaughlin-arm deleted the sme2-form-strided-pseudo branch December 13, 2024 11:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants