
[AArch64] Implement spill/fill of predicate pair register classes #76068

Merged
momchil-velikov merged 3 commits into llvm:main from pred-spill on Dec 22, 2023

Conversation

momchil-velikov (Collaborator) commented on Dec 20, 2023

We get an ICE (internal compiler error) with, for example:

#include <arm_sve.h>

void g();
svboolx2_t f0(int64_t i, int64_t n) {
    svboolx2_t r = svwhilelt_b16_x2(i, n);
    g();
    return r;
}

The predicate pair returned by svwhilelt_b16_x2 is live across the call to g(), so it must be spilled and reloaded, but spill/fill was not implemented for the predicate pair register classes.

llvmbot (Member) commented on Dec 20, 2023

@llvm/pr-subscribers-backend-aarch64

Author: Momchil Velikov (momchil-velikov)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76068.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+15-3)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+17)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+2)
  • (modified) llvm/test/CodeGen/AArch64/spillfill-sve.mir (+92)
  • (added) llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll (+67)
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 757471d6a905e1..bb7f4d907ffd7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -747,6 +747,15 @@ bool AArch64ExpandPseudo::expandSetTagLoop(
 bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
                                              MachineBasicBlock::iterator MBBI,
                                              unsigned Opc, unsigned N) {
+  assert((Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI ||
+          Opc == AArch64::LDR_PXI || Opc == AArch64::STR_PXI) &&
+         "Unexpected opcode");
+  unsigned RState = (Opc == AArch64::LDR_ZXI || Opc == AArch64::LDR_PXI)
+                        ? RegState::Define
+                        : 0;
+  unsigned sub0 = (Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI)
+                      ? AArch64::zsub0
+                      : AArch64::psub0;
   const TargetRegisterInfo *TRI =
       MBB.getParent()->getSubtarget().getRegisterInfo();
   MachineInstr &MI = *MBBI;
@@ -756,9 +765,8 @@ bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
     assert(ImmOffset >= -256 && ImmOffset < 256 &&
            "Immediate spill offset out of range");
     BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opc))
-        .addReg(
-            TRI->getSubReg(MI.getOperand(0).getReg(), AArch64::zsub0 + Offset),
-            Opc == AArch64::LDR_ZXI ? RegState::Define : 0)
+        .addReg(TRI->getSubReg(MI.getOperand(0).getReg(), sub0 + Offset),
+                RState)
         .addReg(MI.getOperand(1).getReg(), getKillRegState(Kill))
         .addImm(ImmOffset);
   }
@@ -1492,12 +1500,16 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
      return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 3);
    case AArch64::STR_ZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 2);
+   case AArch64::STR_PPXI:
+     return expandSVESpillFill(MBB, MBBI, AArch64::STR_PXI, 2);
    case AArch64::LDR_ZZZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 4);
    case AArch64::LDR_ZZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 3);
    case AArch64::LDR_ZZXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 2);
+   case AArch64::LDR_PPXI:
+     return expandSVESpillFill(MBB, MBBI, AArch64::LDR_PXI, 2);
    case AArch64::BLR_RVMARKER:
      return expandCALL_RVMARKER(MBB, MBBI);
    case AArch64::BLR_BTI:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 7d71c316bcb0a2..44a22a6f7ec0e3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2197,6 +2197,7 @@ unsigned AArch64InstrInfo::isLoadFromStackSlot(const MachineInstr &MI,
   case AArch64::LDRDui:
   case AArch64::LDRQui:
   case AArch64::LDR_PXI:
+  case AArch64::LDR_PPXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();
@@ -2221,6 +2222,7 @@ unsigned AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI,
   case AArch64::STRDui:
   case AArch64::STRQui:
   case AArch64::STR_PXI:
+  case AArch64::STR_PPXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();
@@ -3771,6 +3773,13 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
     MinOffset = -256;
     MaxOffset = 255;
     break;
+  case AArch64::LDR_PPXI:
+  case AArch64::STR_PPXI:
+    Scale = TypeSize::getScalable(2);
+    Width = TypeSize::getScalable(2 * 2);
+    MinOffset = -256;
+    MaxOffset = 255;
+    break;
   case AArch64::LDR_ZXI:
   case AArch64::STR_ZXI:
     Scale = TypeSize::getScalable(16);
@@ -4804,6 +4813,10 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
         assert(SrcReg != AArch64::WSP);
     } else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
       Opc = AArch64::STRSui;
+    else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
+      Opc = AArch64::STR_PPXI;
+      StackID = TargetStackID::ScalableVector;
+    }
     break;
   case 8:
     if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
@@ -4980,6 +4993,10 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
         assert(DestReg != AArch64::WSP);
     } else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
       Opc = AArch64::LDRSui;
+    else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
+      Opc = AArch64::LDR_PPXI;
+      StackID = TargetStackID::ScalableVector;
+    }
     break;
   case 8:
     if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index f68059889d0c51..d496bf50e62d10 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2378,11 +2378,13 @@ let Predicates = [HasSVEorSME] in {
     def LDR_ZZXI   : Pseudo<(outs   ZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZXI  : Pseudo<(outs  ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def LDR_PPXI   : Pseudo<(outs PPR2:$pp), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
   let mayStore = 1, hasSideEffects = 0 in {
     def STR_ZZXI   : Pseudo<(outs), (ins   ZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZXI  : Pseudo<(outs), (ins  ZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def STR_PPXI   : Pseudo<(outs), (ins PPR2:$pp, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
   let AddedComplexity = 1 in {
diff --git a/llvm/test/CodeGen/AArch64/spillfill-sve.mir b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
index 01756b84600192..ef7d55a1c2395f 100644
--- a/llvm/test/CodeGen/AArch64/spillfill-sve.mir
+++ b/llvm/test/CodeGen/AArch64/spillfill-sve.mir
@@ -7,6 +7,8 @@
   target triple = "aarch64--linux-gnu"
 
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2() #0 { entry: unreachable }
+  define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2mul2() #0 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_pnr() #1 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_virtreg_pnr() #1 { entry: unreachable }
   define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr() #0 { entry: unreachable }
@@ -64,6 +66,96 @@ body:             |
     RET_ReallyLR
 ...
 ---
+name: spills_fills_stack_id_ppr2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: ppr2 }
+stack:
+liveins:
+  - { reg: '$p0_p1', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $p0_p1
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
+    ; CHECK-NEXT:     stack-id: scalable-vector, callee-saved-register: ''
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2
+    ; EXPAND: STR_PXI $p0, $sp, 6
+    ; EXPAND: STR_PXI $p1, $sp, 7
+    ; EXPAND: $p0 = LDR_PXI $sp, 6
+    ; EXPAND: $p1 = LDR_PXI $sp, 7
+
+    %0:ppr2 = COPY $p0_p1
+
+    $p0 = IMPLICIT_DEF
+    $p1 = IMPLICIT_DEF
+    $p2 = IMPLICIT_DEF
+    $p3 = IMPLICIT_DEF
+    $p4 = IMPLICIT_DEF
+    $p5 = IMPLICIT_DEF
+    $p6 = IMPLICIT_DEF
+    $p7 = IMPLICIT_DEF
+    $p8 = IMPLICIT_DEF
+    $p9 = IMPLICIT_DEF
+    $p10 = IMPLICIT_DEF
+    $p11 = IMPLICIT_DEF
+    $p12 = IMPLICIT_DEF
+    $p13 = IMPLICIT_DEF
+    $p14 = IMPLICIT_DEF
+    $p15 = IMPLICIT_DEF
+
+    $p0_p1 = COPY %0
+    RET_ReallyLR
+...
+---
+name: spills_fills_stack_id_ppr2mul2
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: ppr2mul2 }
+stack:
+liveins:
+  - { reg: '$p0_p1', virtual-reg: '%0' }
+body:             |
+  bb.0.entry:
+    liveins: $p0_p1
+
+    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2mul2
+    ; CHECK: stack:
+    ; CHECK:      - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
+    ; CHECK-NEXT:     stack-id: scalable-vector, callee-saved-register: ''
+
+    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2mul2
+    ; EXPAND: STR_PXI $p0, $sp, 6
+    ; EXPAND: STR_PXI $p1, $sp, 7
+    ; EXPAND: $p0 = LDR_PXI $sp, 6
+    ; EXPAND: $p1 = LDR_PXI $sp, 7
+
+    %0:ppr2mul2 = COPY $p0_p1
+
+    $p0 = IMPLICIT_DEF
+    $p1 = IMPLICIT_DEF
+    $p2 = IMPLICIT_DEF
+    $p3 = IMPLICIT_DEF
+    $p4 = IMPLICIT_DEF
+    $p5 = IMPLICIT_DEF
+    $p6 = IMPLICIT_DEF
+    $p7 = IMPLICIT_DEF
+    $p8 = IMPLICIT_DEF
+    $p9 = IMPLICIT_DEF
+    $p10 = IMPLICIT_DEF
+    $p11 = IMPLICIT_DEF
+    $p12 = IMPLICIT_DEF
+    $p13 = IMPLICIT_DEF
+    $p14 = IMPLICIT_DEF
+    $p15 = IMPLICIT_DEF
+
+    $p0_p1 = COPY %0
+    RET_ReallyLR
+...
+---
 name: spills_fills_stack_id_pnr
 tracksRegLiveness: true
 registers:
diff --git a/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
new file mode 100644
index 00000000000000..eb7950ef10c9c4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
@@ -0,0 +1,67 @@
+; RUN: llc < %s | FileCheck %s
+
+; Derived from 
+; #include <arm_sve.h>
+
+; void g();
+
+; svboolx2_t f0(int64_t i, int64_t n) {
+;     svboolx2_t r = svwhilelt_b16_x2(i, n);
+;     g();
+;     return r;
+; }
+
+; svboolx2_t f1(svcount_t n) {
+;     svboolx2_t r = svpext_lane_c8_x2(n, 1);
+;     g();
+;     return r;
+; }
+; 
+; Check that predicate register pairs are spilled/filled without an ICE in the backend.
+
+target triple = "aarch64-unknown-linux"
+
+define <vscale x 32 x i1> @f0(i64 %i, i64 %n) #0 {
+entry:
+  %0 = tail call { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64 %i, i64 %n)
+  %1 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 0
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
+  %3 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %2, i64 0)
+  %4 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 1
+  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %4)
+  %6 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %3, <vscale x 16 x i1> %5, i64 16)
+  tail call void @g() #4
+  ret <vscale x 32 x i1> %6
+}
+; CHECK-LABEL: f0:
+; CHECK: whilelt { p0.h, p1.h }
+; CHECK: str p0, [sp, #6, mul vl]
+; CHECK: str p1, [sp, #7, mul vl]
+; CHECK: ldr p0, [sp, #6, mul vl]
+; CHECK: ldr p1, [sp, #7, mul vl]
+
+define <vscale x 32 x i1> @f1(target("aarch64.svcount") %n) #0 {
+entry:
+  %0 = tail call { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount") %n, i32 1)
+  %1 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 0
+  %2 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %1, i64 0)
+  %3 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 1
+  %4 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %2, <vscale x 16 x i1> %3, i64 16)
+  tail call void @g() #4
+  ret <vscale x 32 x i1> %4
+}
+
+; CHECK-LABEL: f1:
+; CHECK: pext { p0.b, p1.b }
+; CHECK: str p0, [sp, #6, mul vl]
+; CHECK: str p1, [sp, #7, mul vl]
+; CHECK: ldr p0, [sp, #6, mul vl]
+; CHECK: ldr p1, [sp, #7, mul vl]
+
+declare void @g(...)
+declare { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64, i64)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1>, <vscale x 16 x i1>, i64 immarg)
+declare { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount"), i32 immarg) #1
+
+attributes #0 = { nounwind "target-features"="+sve,+sve2,+sve2p1" }

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2378,11 +2378,13 @@ let Predicates = [HasSVEorSME] in {
     def LDR_ZZXI   : Pseudo<(outs   ZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZXI  : Pseudo<(outs  ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
     def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
+    def LDR_PPXI   : Pseudo<(outs PPR2:$pp), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
Collaborator

simm4s1 seems weird given the underlying instructions use simm9 but then we don't target the pseudos via ISEL so I guess it doesn't really matter. It's also consistent with the existing pseudos, which also look wrong.

Collaborator Author

These offsets are checked when the frame index is replaced with the actual addressing mode, so we know we are in range. I would suggest changing all of these to simm9 in a separate patch, if only so nobody wonders why we used such strange constraints.
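
For reference, a minimal sketch of the suggested follow-up (hypothetical, not part of this patch): it only swaps the pseudos' offset operand class to simm9, the operand class used by the underlying predicate LDR/STR instructions, and assumes nothing else needs to change (the surrounding mayLoad/mayStore blocks are omitted here).

    // Hypothetical follow-up: widen the pseudos' offset operand from simm4s1
    // to simm9 so it mirrors the range of the real LDR/STR (predicate) forms.
    def LDR_PPXI : Pseudo<(outs PPR2:$pp), (ins GPR64sp:$sp, simm9:$offset), []>, Sched<[]>;
    def STR_PPXI : Pseudo<(outs), (ins PPR2:$pp, GPR64sp:$sp, simm9:$offset), []>, Sched<[]>;

The expansion in AArch64ExpandPseudoInsts.cpp should be unaffected either way, since it only reads the immediate operand and already asserts that the offset is within the -256..255 range of the underlying instructions.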

momchil-velikov merged commit 4b69689 into llvm:main on Dec 22, 2023
momchil-velikov deleted the pred-spill branch on January 30, 2024, 10:35