Skip to content

[AArch64] Implement spill/fill of predicate pair register classes #76068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,15 @@ bool AArch64ExpandPseudo::expandSetTagLoop(
bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
unsigned Opc, unsigned N) {
assert((Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI ||
Opc == AArch64::LDR_PXI || Opc == AArch64::STR_PXI) &&
"Unexpected opcode");
unsigned RState = (Opc == AArch64::LDR_ZXI || Opc == AArch64::LDR_PXI)
? RegState::Define
: 0;
unsigned sub0 = (Opc == AArch64::LDR_ZXI || Opc == AArch64::STR_ZXI)
? AArch64::zsub0
: AArch64::psub0;
const TargetRegisterInfo *TRI =
MBB.getParent()->getSubtarget().getRegisterInfo();
MachineInstr &MI = *MBBI;
Expand All @@ -756,9 +765,8 @@ bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
assert(ImmOffset >= -256 && ImmOffset < 256 &&
"Immediate spill offset out of range");
BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opc))
.addReg(
TRI->getSubReg(MI.getOperand(0).getReg(), AArch64::zsub0 + Offset),
Opc == AArch64::LDR_ZXI ? RegState::Define : 0)
.addReg(TRI->getSubReg(MI.getOperand(0).getReg(), sub0 + Offset),
RState)
.addReg(MI.getOperand(1).getReg(), getKillRegState(Kill))
.addImm(ImmOffset);
}
Expand Down Expand Up @@ -1492,12 +1500,16 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 3);
case AArch64::STR_ZZXI:
return expandSVESpillFill(MBB, MBBI, AArch64::STR_ZXI, 2);
case AArch64::STR_PPXI:
return expandSVESpillFill(MBB, MBBI, AArch64::STR_PXI, 2);
case AArch64::LDR_ZZZZXI:
return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 4);
case AArch64::LDR_ZZZXI:
return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 3);
case AArch64::LDR_ZZXI:
return expandSVESpillFill(MBB, MBBI, AArch64::LDR_ZXI, 2);
case AArch64::LDR_PPXI:
return expandSVESpillFill(MBB, MBBI, AArch64::LDR_PXI, 2);
case AArch64::BLR_RVMARKER:
return expandCALL_RVMARKER(MBB, MBBI);
case AArch64::BLR_BTI:
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3771,6 +3771,13 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
MinOffset = -256;
MaxOffset = 255;
break;
case AArch64::LDR_PPXI:
case AArch64::STR_PPXI:
Scale = TypeSize::getScalable(2);
Width = TypeSize::getScalable(2 * 2);
MinOffset = -256;
MaxOffset = 254;
break;
case AArch64::LDR_ZXI:
case AArch64::STR_ZXI:
Scale = TypeSize::getScalable(16);
Expand Down Expand Up @@ -4804,6 +4811,10 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
assert(SrcReg != AArch64::WSP);
} else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
Opc = AArch64::STRSui;
else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
Opc = AArch64::STR_PPXI;
StackID = TargetStackID::ScalableVector;
}
break;
case 8:
if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
Expand Down Expand Up @@ -4980,6 +4991,10 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
assert(DestReg != AArch64::WSP);
} else if (AArch64::FPR32RegClass.hasSubClassEq(RC))
Opc = AArch64::LDRSui;
else if (AArch64::PPR2RegClass.hasSubClassEq(RC)) {
Opc = AArch64::LDR_PPXI;
StackID = TargetStackID::ScalableVector;
}
break;
case 8:
if (AArch64::GPR64allRegClass.hasSubClassEq(RC)) {
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2378,11 +2378,13 @@ let Predicates = [HasSVEorSME] in {
def LDR_ZZXI : Pseudo<(outs ZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
def LDR_ZZZXI : Pseudo<(outs ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b_strided_and_contiguous:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
def LDR_PPXI : Pseudo<(outs PPR2:$pp), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simm4s1 seems weird given the underlying instructions use simm9 but then we don't target the pseudos via ISEL so I guess it doesn't really matter. It's also consistent with the existing pseudos, which also look wrong.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These offsets are checked when replacing the frame index with the actual addressing mode, so we know we are in range. I would suggest changing all of these to simm9 in a separate patch, if only so that nobody wonders why we used such strange constraints.

}
let mayStore = 1, hasSideEffects = 0 in {
def STR_ZZXI : Pseudo<(outs), (ins ZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
def STR_ZZZXI : Pseudo<(outs), (ins ZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b_strided_and_contiguous:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
def STR_PPXI : Pseudo<(outs), (ins PPR2:$pp, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
}

let AddedComplexity = 1 in {
Expand Down
92 changes: 92 additions & 0 deletions llvm/test/CodeGen/AArch64/spillfill-sve.mir
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
target triple = "aarch64--linux-gnu"

define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr() #0 { entry: unreachable }
define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2() #0 { entry: unreachable }
define aarch64_sve_vector_pcs void @spills_fills_stack_id_ppr2mul2() #0 { entry: unreachable }
define aarch64_sve_vector_pcs void @spills_fills_stack_id_pnr() #1 { entry: unreachable }
define aarch64_sve_vector_pcs void @spills_fills_stack_id_virtreg_pnr() #1 { entry: unreachable }
define aarch64_sve_vector_pcs void @spills_fills_stack_id_zpr() #0 { entry: unreachable }
Expand Down Expand Up @@ -64,6 +66,96 @@ body: |
RET_ReallyLR
...
---
# Spill/fill of a virtual register of class ppr2 (an SVE predicate-register
# pair). Clobbering all of $p0-$p15 between the two COPYs forces the register
# allocator to spill %0. The CHECK lines verify the spill slot is created with
# the scalable-vector stack ID; the EXPAND lines verify the STR_PPXI/LDR_PPXI
# pair pseudos are expanded into two single-predicate STR_PXI/LDR_PXI
# instructions at consecutive immediate offsets (see expandSVESpillFill).
name: spills_fills_stack_id_ppr2
tracksRegLiveness: true
registers:
  - { id: 0, class: ppr2 }
stack:
liveins:
  - { reg: '$p0_p1', virtual-reg: '%0' }
body: |
  bb.0.entry:
    liveins: $p0_p1

    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2
    ; CHECK: stack:
    ; CHECK: - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
    ; CHECK-NEXT: stack-id: scalable-vector, callee-saved-register: ''

    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2
    ; EXPAND: STR_PXI $p0, $sp, 6
    ; EXPAND: STR_PXI $p1, $sp, 7
    ; EXPAND: $p0 = LDR_PXI $sp, 6
    ; EXPAND: $p1 = LDR_PXI $sp, 7

    %0:ppr2 = COPY $p0_p1

    ; Clobber every predicate register so %0 cannot stay in registers
    ; across this region and must be spilled/filled.
    $p0 = IMPLICIT_DEF
    $p1 = IMPLICIT_DEF
    $p2 = IMPLICIT_DEF
    $p3 = IMPLICIT_DEF
    $p4 = IMPLICIT_DEF
    $p5 = IMPLICIT_DEF
    $p6 = IMPLICIT_DEF
    $p7 = IMPLICIT_DEF
    $p8 = IMPLICIT_DEF
    $p9 = IMPLICIT_DEF
    $p10 = IMPLICIT_DEF
    $p11 = IMPLICIT_DEF
    $p12 = IMPLICIT_DEF
    $p13 = IMPLICIT_DEF
    $p14 = IMPLICIT_DEF
    $p15 = IMPLICIT_DEF

    $p0_p1 = COPY %0
    RET_ReallyLR
...
---
# Same as spills_fills_stack_id_ppr2, but for the ppr2mul2 register class
# (predicate pairs restricted to even/odd register pairs). Clobbering all of
# $p0-$p15 between the two COPYs forces a spill of %0; the CHECK lines verify
# the scalable-vector spill slot and the EXPAND lines verify expansion into
# per-register STR_PXI/LDR_PXI at consecutive offsets.
#
# Fix: the CHECK-LABEL previously named spills_fills_stack_id_ppr2 (a
# copy-paste from the preceding test); it must name this function so FileCheck
# anchors the CHECK lines to the correct output section rather than relying on
# a substring match against the longer name.
name: spills_fills_stack_id_ppr2mul2
tracksRegLiveness: true
registers:
  - { id: 0, class: ppr2mul2 }
stack:
liveins:
  - { reg: '$p0_p1', virtual-reg: '%0' }
body: |
  bb.0.entry:
    liveins: $p0_p1

    ; CHECK-LABEL: name: spills_fills_stack_id_ppr2mul2
    ; CHECK: stack:
    ; CHECK: - { id: 0, name: '', type: spill-slot, offset: 0, size: 4, alignment: 2
    ; CHECK-NEXT: stack-id: scalable-vector, callee-saved-register: ''

    ; EXPAND-LABEL: name: spills_fills_stack_id_ppr2mul2
    ; EXPAND: STR_PXI $p0, $sp, 6
    ; EXPAND: STR_PXI $p1, $sp, 7
    ; EXPAND: $p0 = LDR_PXI $sp, 6
    ; EXPAND: $p1 = LDR_PXI $sp, 7

    %0:ppr2mul2 = COPY $p0_p1

    ; Clobber every predicate register so %0 cannot stay in registers
    ; across this region and must be spilled/filled.
    $p0 = IMPLICIT_DEF
    $p1 = IMPLICIT_DEF
    $p2 = IMPLICIT_DEF
    $p3 = IMPLICIT_DEF
    $p4 = IMPLICIT_DEF
    $p5 = IMPLICIT_DEF
    $p6 = IMPLICIT_DEF
    $p7 = IMPLICIT_DEF
    $p8 = IMPLICIT_DEF
    $p9 = IMPLICIT_DEF
    $p10 = IMPLICIT_DEF
    $p11 = IMPLICIT_DEF
    $p12 = IMPLICIT_DEF
    $p13 = IMPLICIT_DEF
    $p14 = IMPLICIT_DEF
    $p15 = IMPLICIT_DEF

    $p0_p1 = COPY %0
    RET_ReallyLR
...
---
name: spills_fills_stack_id_pnr
tracksRegLiveness: true
registers:
Expand Down
67 changes: 67 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-pred-pair-spill-fill.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
; RUN: llc < %s | FileCheck %s

; Derived from
; #include <arm_sve.h>

; void g();

; svboolx2_t f0(int64_t i, int64_t n) {
; svboolx2_t r = svwhilelt_b16_x2(i, n);
; g();
; return r;
; }

; svboolx2_t f1(svcount_t n) {
; svboolx2_t r = svpext_lane_c8_x2(n, 1);
; g();
; return r;
; }
;
; Check that predicate register pairs are spilled/filled without an ICE in the backend.

target triple = "aarch64-unknown-linux"

; f0: the predicate pair produced by whilelt.x2 is live across the call to
; @g(), so the register allocator must spill the ppr2 pair before the call
; and fill it afterwards. The CHECK lines verify the pair is stored/loaded
; as two single-predicate str/ldr instructions at consecutive mul-vl offsets.
define <vscale x 32 x i1> @f0(i64 %i, i64 %n) #0 {
entry:
  %0 = tail call { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64 %i, i64 %n)
  %1 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 0
  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
  %3 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %2, i64 0)
  %4 = extractvalue { <vscale x 8 x i1>, <vscale x 8 x i1> } %0, 1
  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %4)
  %6 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %3, <vscale x 16 x i1> %5, i64 16)
  tail call void @g()
  ret <vscale x 32 x i1> %6
}
; CHECK-LABEL: f0:
; CHECK: whilelt { p0.h, p1.h }
; CHECK: str p0, [sp, #6, mul vl]
; CHECK: str p1, [sp, #7, mul vl]
; CHECK: ldr p0, [sp, #6, mul vl]
; CHECK: ldr p1, [sp, #7, mul vl]
; f1: same as f0 but the pair comes from pext on an svcount value; the pair
; is live across the call to @g() and must be spilled/filled as a ppr2 pair
; (two str/ldr of single predicates at consecutive mul-vl offsets).
define <vscale x 32 x i1> @f1(target("aarch64.svcount") %n) #0 {
entry:
  %0 = tail call { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount") %n, i32 1)
  %1 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 0
  %2 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> poison, <vscale x 16 x i1> %1, i64 0)
  %3 = extractvalue { <vscale x 16 x i1>, <vscale x 16 x i1> } %0, 1
  %4 = tail call <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1> %2, <vscale x 16 x i1> %3, i64 16)
  tail call void @g()
  ret <vscale x 32 x i1> %4
}

; CHECK-LABEL: f1:
; CHECK: pext { p0.b, p1.b }
; CHECK: str p0, [sp, #6, mul vl]
; CHECK: str p1, [sp, #7, mul vl]
; CHECK: ldr p0, [sp, #6, mul vl]
; CHECK: ldr p1, [sp, #7, mul vl]

declare void @g(...)
declare { <vscale x 8 x i1>, <vscale x 8 x i1> } @llvm.aarch64.sve.whilelt.x2.nxv8i1(i64, i64)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
declare <vscale x 32 x i1> @llvm.vector.insert.nxv32i1.nxv16i1(<vscale x 32 x i1>, <vscale x 16 x i1>, i64 immarg)
declare { <vscale x 16 x i1>, <vscale x 16 x i1> } @llvm.aarch64.sve.pext.x2.nxv16i1(target("aarch64.svcount"), i32 immarg) #1

attributes #0 = { nounwind "target-features"="+sve,+sve2,+sve2p1" }