Skip to content

Commit 44567f6

Browse files
committed
[AArch64][GISel] Translate legal SVE formal arguments and select COPY for SVE
1 parent 0851d7b commit 44567f6

File tree

6 files changed

+456
-12
lines changed

6 files changed

+456
-12
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
149149
// scalable vector types for all instructions, even if SVE is not yet supported
150150
// with some instructions.
151151
// See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
152-
static cl::opt<bool> EnableSVEGISel(
152+
cl::opt<bool> EnableSVEGISel(
153153
"aarch64-enable-gisel-sve", cl::Hidden,
154154
cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
155155
cl::init(false));

llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
using namespace llvm;
5454
using namespace AArch64GISelUtils;
5555

56+
extern cl::opt<bool> EnableSVEGISel;
57+
5658
AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
5759
: CallLowering(&TLI) {}
5860

@@ -525,10 +527,10 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,
525527

526528
bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
527529
auto &F = MF.getFunction();
528-
if (F.getReturnType()->isScalableTy() ||
530+
if (!EnableSVEGISel && (F.getReturnType()->isScalableTy() ||
529531
llvm::any_of(F.args(), [](const Argument &A) {
530532
return A.getType()->isScalableTy();
531-
}))
533+
})))
532534
return true;
533535
const auto &ST = MF.getSubtarget<AArch64Subtarget>();
534536
if (!ST.hasNEON() || !ST.hasFPARMv8()) {

llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,14 @@ getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB,
597597
/// Given a register bank and a size in bits, return the smallest register class
598598
/// that can represent that combination.
599599
static const TargetRegisterClass *
600-
getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
600+
getMinClassForRegBank(const RegisterBank &RB, TypeSize SizeInBits,
601601
bool GetAllRegSet = false) {
602+
if (SizeInBits.isScalable()) {
603+
assert(RB.getID() == AArch64::FPRRegBankID
604+
&& "Expected FPR regbank for scalable type size");
605+
return &AArch64::ZPRRegClass;
606+
}
607+
602608
unsigned RegBankID = RB.getID();
603609

604610
if (RegBankID == AArch64::GPRRegBankID) {
@@ -939,8 +945,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
939945
Register SrcReg = I.getOperand(1).getReg();
940946
const RegisterBank &DstRegBank = *RBI.getRegBank(DstReg, MRI, TRI);
941947
const RegisterBank &SrcRegBank = *RBI.getRegBank(SrcReg, MRI, TRI);
942-
unsigned DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
943-
unsigned SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
948+
949+
TypeSize DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
950+
TypeSize SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
944951

945952
// Special casing for cross-bank copies of s1s. We can technically represent
946953
// a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -951,7 +958,7 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
951958
// register bank. Or make a new helper that carries along some constraint
952959
// information.
953960
if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
954-
SrcSize = DstSize = 32;
961+
SrcSize = DstSize = TypeSize::getFixed(32);
955962

956963
return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
957964
getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1016,8 +1023,8 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
10161023
return false;
10171024
}
10181025

1019-
unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
1020-
unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
1026+
const TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
1027+
const TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
10211028
unsigned SubReg;
10221029

10231030
// If the source bank doesn't support a subregister copy small enough,

llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
258258
case AArch64::QQQRegClassID:
259259
case AArch64::QQQQRegClassID:
260260
case AArch64::ZPRRegClassID:
261+
case AArch64::ZPR_3bRegClassID:
261262
return getRegBank(AArch64::FPRRegBankID);
262263
case AArch64::GPR32commonRegClassID:
263264
case AArch64::GPR32RegClassID:
@@ -714,10 +715,10 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
714715
// If both RBs are null, that means both registers are generic.
715716
// We shouldn't be here.
716717
assert(DstRB && SrcRB && "Both RegBank were nullptr");
717-
unsigned Size = getSizeInBits(DstReg, MRI, TRI);
718+
TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
718719
return getInstructionMapping(
719-
DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::getFixed(Size)),
720-
getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
720+
DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
721+
getCopyMapping(DstRB->getID(), SrcRB->getID(), Size.getKnownMinValue()),
721722
// We only care about the mapping of the destination.
722723
/*NumOperands*/ 1);
723724
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=aarch64-linux-gnu -O0 -mattr=+sve -global-isel -global-isel-abort=1 -aarch64-enable-gisel-sve=1 %s -o - | FileCheck %s
3+
;; vscale x 128-bit
4+
5+
define void @formal_argument_nxv16i8(<vscale x 16 x i8> %0, ptr %p) {
6+
; CHECK-LABEL: formal_argument_nxv16i8:
7+
; CHECK: // %bb.0:
8+
; CHECK-NEXT: ptrue p0.b
9+
; CHECK-NEXT: st1b { z0.b }, p0, [x0]
10+
; CHECK-NEXT: ret
11+
store <vscale x 16 x i8> %0, ptr %p, align 16
12+
ret void
13+
}
14+
15+
define void @formal_argument_nxv8i16(<vscale x 8 x i16> %0, ptr %p) {
16+
; CHECK-LABEL: formal_argument_nxv8i16:
17+
; CHECK: // %bb.0:
18+
; CHECK-NEXT: ptrue p0.h
19+
; CHECK-NEXT: st1h { z0.h }, p0, [x0]
20+
; CHECK-NEXT: ret
21+
store <vscale x 8 x i16> %0, ptr %p, align 16
22+
ret void
23+
}
24+
25+
define void @formal_argument_nxv4i32(<vscale x 4 x i32> %0, ptr %p) {
26+
; CHECK-LABEL: formal_argument_nxv4i32:
27+
; CHECK: // %bb.0:
28+
; CHECK-NEXT: ptrue p0.s
29+
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
30+
; CHECK-NEXT: ret
31+
store <vscale x 4 x i32> %0, ptr %p, align 16
32+
ret void
33+
}
34+
35+
define void @formal_argument_nxv2i64(<vscale x 2 x i64> %0, ptr %p) {
36+
; CHECK-LABEL: formal_argument_nxv2i64:
37+
; CHECK: // %bb.0:
38+
; CHECK-NEXT: ptrue p0.d
39+
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
40+
; CHECK-NEXT: ret
41+
store <vscale x 2 x i64> %0, ptr %p, align 16
42+
ret void
43+
}
44+
45+
;; TODO: Add tests for other types when store is supported for them.

0 commit comments

Comments
 (0)