Commit 8bce40b

[AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE (#92130)

This patch adds basic support for scalable vector types in load & store instructions for AArch64 with GISel. Only scalable vector types with a 128-bit base size are supported, e.g. `<vscale x 4 x i32>`, `<vscale x 16 x i8>`. The patch adapts some ideas from a similar, abandoned patch: https://github.com/llvm/llvm-project/pull/72976.
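For orientation, a "128-bit base size" means the low-level type (LLT) carries a known minimum of 128 bits that is scaled by the runtime vscale. A minimal sketch of how such a type is built and sized, mirroring the constants this patch adds in AArch64LegalizerInfo.cpp (header path as in recent LLVM; older trees use llvm/CodeGen/LowLevelType.h):

#include <cassert>
#include "llvm/CodeGenTypes/LowLevelType.h" // LLT
#include "llvm/Support/TypeSize.h"

using namespace llvm;

void sketch() {
  // <vscale x 4 x i32>: at least 4 x 32 bits, multiplied by vscale at run time.
  const LLT nxv4s32 = LLT::scalable_vector(4, LLT::scalar(32));
  assert(nxv4s32.isScalableVector());
  assert(nxv4s32.getSizeInBits() == TypeSize::getScalable(128));
}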
1 parent 67897d7 commit 8bce40b

9 files changed (+128 −22 lines)


llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h

Lines changed: 4 additions & 4 deletions

@@ -669,17 +669,17 @@ bool GIMatchTableExecutor::executeMatchTable(
       MachineMemOperand *MMO =
           *(State.MIs[InsnID]->memoperands_begin() + MMOIdx);
 
-      unsigned Size = MRI.getType(MO.getReg()).getSizeInBits();
+      const TypeSize Size = MRI.getType(MO.getReg()).getSizeInBits();
       if (MatcherOpcode == GIM_CheckMemorySizeEqualToLLT &&
-          MMO->getSizeInBits().getValue() != Size) {
+          MMO->getSizeInBits() != Size) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeLessThanLLT &&
-                 MMO->getSizeInBits().getValue() >= Size) {
+                 TypeSize::isKnownGE(MMO->getSizeInBits().getValue(), Size)) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeGreaterThanLLT &&
-                 MMO->getSizeInBits().getValue() <= Size)
+                 TypeSize::isKnownLE(MMO->getSizeInBits().getValue(), Size))
         if (handleReject() == RejectAndGiveUp)
           return false;
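Replacing the `unsigned` size with a `TypeSize` changes the comparison semantics: `==`/`!=` now also compare scalability, and the ordered checks go through the `isKnown*` predicates, which succeed only when the relation holds for every possible vscale. A minimal sketch of those semantics, as I read llvm/Support/TypeSize.h:

#include "llvm/Support/TypeSize.h"

using namespace llvm;

void typeSizeSemantics() {
  TypeSize Fixed = TypeSize::getFixed(128);   // exactly 128 bits
  TypeSize Scal = TypeSize::getScalable(128); // vscale * 128 bits

  bool Eq = (Fixed == Scal);                     // false: scalability differs
  bool GE = TypeSize::isKnownGE(Scal, Fixed);    // true: vscale >= 1, so vscale*128 >= 128
  bool NotGE = TypeSize::isKnownGE(Fixed, Scal); // false: cannot hold for every vscale
  (void)Eq; (void)GE; (void)NotGE;
}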

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 2 additions & 1 deletion

@@ -1150,7 +1150,8 @@ bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
   LLT Ty = MRI.getType(LdSt.getReg(0));
   LLT MemTy = LdSt.getMMO().getMemoryType();
   SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
-      {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
+      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(),
+        AtomicOrdering::NotAtomic}});
   unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
   SmallVector<LLT> OpTys;
   if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
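The middle `MemDesc` field is a plain integer, so the `TypeSize` returned by `getSizeInBits()` has to be collapsed explicitly; `getKnownMinValue()` is the scalability-safe way to do that. A small sketch of the accessor contract, assuming the standard `TypeSize` API:

#include <cstdint>
#include "llvm/Support/TypeSize.h"

using namespace llvm;

void knownMin() {
  TypeSize Scal = TypeSize::getScalable(128);
  uint64_t Min = Scal.getKnownMinValue(); // 128: the compile-time-known minimum
  // Scal.getFixedValue() would assert here, because the size is scalable,
  // and the implicit conversion to uint64_t goes through the same check.
  (void)Min;
}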

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 1 addition & 1 deletion

@@ -1413,7 +1413,7 @@ bool IRTranslator::translateLoad(const User &U, MachineIRBuilder &MIRBuilder) {
 
 bool IRTranslator::translateStore(const User &U, MachineIRBuilder &MIRBuilder) {
   const StoreInst &SI = cast<StoreInst>(U);
-  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()) == 0)
+  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()).isZero())
     return true;
 
   ArrayRef<Register> Vals = getOrCreateVRegs(*SI.getValueOperand());
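`DataLayout::getTypeStoreSize` returns a `TypeSize`; comparing it with `== 0` forces the implicit integer conversion, which is invalid for scalable types, while `isZero()` answers the question without collapsing the size. A minimal sketch, assuming the documented `DataLayout`/`TypeSize` behavior:

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/TypeSize.h"

using namespace llvm;

bool storesNothing(const DataLayout &DL, Type *Ty) {
  TypeSize Size = DL.getTypeStoreSize(Ty); // fixed or scalable
  // Zero is zero regardless of vscale, so isZero() needs no conversion.
  return Size.isZero();
}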

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 8 deletions

@@ -145,6 +145,15 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                  cl::desc("Maximum of xors"));
 
+// By turning this on, we will not fallback to DAG ISel when encountering
+// scalable vector types for all instruction, even if SVE is not yet supported
+// with some instructions.
+// See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
+static cl::opt<bool> EnableSVEGISel(
+    "aarch64-enable-gisel-sve", cl::Hidden,
+    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+    cl::init(false));
+
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;

@@ -26469,16 +26478,22 @@ bool AArch64TargetLowering::shouldLocalize(
 }
 
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
-    return true;
-
-  for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+  // Fallback for scalable vectors.
+  // Note that if EnableSVEGISel is true, we allow scalable vector types for
+  // all instructions, regardless of whether they are actually supported.
+  if (!EnableSVEGISel) {
+    if (Inst.getType()->isScalableTy()) {
       return true;
+    }
 
-  if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (AI->getAllocatedType()->isScalableTy())
-      return true;
+    for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
+      if (Inst.getOperand(i)->getType()->isScalableTy())
+        return true;
+
+    if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
+      if (AI->getAllocatedType()->isScalableTy())
+        return true;
+    }
   }
 
   // Checks to allow the use of SME instructions
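The option defaults to off, so the SVE GlobalISel path is opt-in for now: without it, any instruction touching a scalable type still falls back to SelectionDAG. The new test below enables it explicitly; the invocation looks like this (input.ll is a placeholder for any module containing scalable-vector loads and stores):

llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel \
    -aarch64-enable-gisel-sve=true input.ll -o -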

llvm/lib/Target/AArch64/AArch64RegisterBanks.td

Lines changed: 2 additions & 2 deletions

@@ -12,8 +12,8 @@
 /// General Purpose Registers: W, X.
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
 
-/// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+/// Floating Point, Vector, Scalable Vector Registers: B, H, S, D, Q, Z.
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
 
 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 31 additions & 2 deletions

@@ -61,6 +61,11 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
 
+  const LLT nxv16s8 = LLT::scalable_vector(16, s8);
+  const LLT nxv8s16 = LLT::scalable_vector(8, s16);
+  const LLT nxv4s32 = LLT::scalable_vector(4, s32);
+  const LLT nxv2s64 = LLT::scalable_vector(2, s64);
+
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
                                                         v2s64, v2p0,

@@ -328,7 +333,31 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
-  getActionDefinitionsBuilder(G_LOAD)
+  auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
+  auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
+
+  if (ST.hasSVE()) {
+    LoadActions.legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 8},
+        {nxv8s16, p0, nxv8s16, 8},
+        {nxv4s32, p0, nxv4s32, 8},
+        {nxv2s64, p0, nxv2s64, 8},
+    });
+
+    // TODO: Add nxv2p0. Consider bitcastIf.
+    // See #92130
+    // https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
+    StoreActions.legalForTypesWithMemDesc({
+        // 128 bit base sizes
+        {nxv16s8, p0, nxv16s8, 8},
+        {nxv8s16, p0, nxv8s16, 8},
+        {nxv4s32, p0, nxv4s32, 8},
+        {nxv2s64, p0, nxv2s64, 8},
+    });
+  }
+
+  LoadActions
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;

@@ -378,7 +407,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(IsPtrVecPred)
       .scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
 
-  getActionDefinitionsBuilder(G_STORE)
+  StoreActions
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
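Each `legalForTypesWithMemDesc` entry appears to be {value type (type index 0), pointer type (type index 1), memory type, minimum alignment in bits}, so these rules accept only non-extending, at-least-byte-aligned accesses. A hand-rolled sketch of the predicate the first load rule expresses, illustrative only and assuming the usual `LegalityQuery` layout:

#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"

using namespace llvm;

// Roughly what {nxv16s8, p0, nxv16s8, 8} asks of a G_LOAD query.
static bool matchesSveLoadRule(const LegalityQuery &Q) {
  const LLT NxV16S8 = LLT::scalable_vector(16, LLT::scalar(8));
  const LLT P0 = LLT::pointer(0, 64);
  return Q.Types[0] == NxV16S8 && Q.Types[1] == P0 &&
         Q.MMODescrs[0].MemoryTy == NxV16S8 && // non-extending access
         Q.MMODescrs[0].AlignInBits >= 8;      // at least byte-aligned
}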

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp

Lines changed: 7 additions & 0 deletions

@@ -309,6 +309,8 @@ bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
   if (!Store.isSimple())
     return false;
   LLT ValTy = MRI.getType(Store.getValueReg());
+  if (ValTy.isScalableVector())
+    return false;
   if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
     return false;
   if (Store.getMemSizeInBits() != ValTy.getSizeInBits())

@@ -708,6 +710,11 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
   // should only be in a single block.
   resetState();
   for (auto &MI : MBB) {
+    // Skip for scalable vectors
+    if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
+        LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
+      continue;
+
     if (auto *St = dyn_cast<GStore>(&MI)) {
       Register PtrBaseReg;
       APInt Offset;

llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp

Lines changed: 8 additions & 4 deletions

@@ -257,6 +257,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:

@@ -743,12 +744,15 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     LLT Ty = MRI.getType(MO.getReg());
     if (!Ty.isValid())
       continue;
-    OpSize[Idx] = Ty.getSizeInBits();
+    OpSize[Idx] = Ty.getSizeInBits().getKnownMinValue();
 
-    // As a top-level guess, vectors go in FPRs, scalars and pointers in GPRs.
+    // As a top-level guess, vectors including both scalable and non-scalable
+    // ones go in FPRs, scalars and pointers in GPRs.
     // For floating-point instructions, scalars go in FPRs.
-    if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
-        Ty.getSizeInBits() > 64)
+    if (Ty.isVector())
+      OpRegBankIdx[Idx] = PMI_FirstFPR;
+    else if (isPreISelGenericFloatingPointOpcode(Opc) ||
+             Ty.getSizeInBits() > 64)
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else
       OpRegBankIdx[Idx] = PMI_FirstGPR;
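Splitting the condition is not just cosmetic: `Ty.getSizeInBits() > 64` implicitly converts the `TypeSize` to a fixed integer, which is not defined for scalable types, so vectors (scalable ones included) have to be classified before that comparison can run. A minimal sketch of the hazard, under the same `TypeSize` assumptions as above:

#include "llvm/CodeGenTypes/LowLevelType.h" // LLT (path in recent LLVM)

using namespace llvm;

void classify(LLT Ty) {
  if (Ty.isVector()) {
    // FPR bank; no size conversion is attempted, so scalable types are safe.
  } else if (Ty.getSizeInBits() > 64) {
    // Only reached for scalars/pointers, whose TypeSize is always fixed,
    // so the implicit TypeSize -> uint64_t conversion cannot assert.
  }
}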
New test file

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -global-isel -aarch64-enable-gisel-sve=true < %s | FileCheck %s
+
+define void @scalable_v16i8(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [x0]
+; CHECK-NEXT:    st1b { z0.b }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 16 x i8>, ptr %l0, align 16
+  store <vscale x 16 x i8> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v8i16(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    st1h { z0.h }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 8 x i16>, ptr %l0, align 16
+  store <vscale x 8 x i16> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v4i32(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 4 x i32>, ptr %l0, align 16
+  store <vscale x 4 x i32> %l3, ptr %l1, align 16
+  ret void
+}
+
+define void @scalable_v2i64(ptr %l0, ptr %l1) {
+; CHECK-LABEL: scalable_v2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
+; CHECK-NEXT:    st1d { z0.d }, p0, [x1]
+; CHECK-NEXT:    ret
+  %l3 = load <vscale x 2 x i64>, ptr %l0, align 16
+  store <vscale x 2 x i64> %l3, ptr %l1, align 16
+  ret void
+}
