Commit a61fb1a
[AArch64][GISel] Support SVE with 128-bit min-size for G_LOAD and G_STORE
This patch adds basic support for scalable vector types in load & store instructions for AArch64 with GISel. Only scalable vector types with a 128-bit minimum size are supported, e.g. <vscale x 4 x i32> and <vscale x 16 x i8>. The patch adapts some ideas from a similar, abandoned patch: llvm#72976.
1 parent 1a49810 commit a61fb1a
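For illustration only (not part of the commit), here is a minimal IR sketch of the kind of load/store this change targets, assuming an AArch64 subtarget with SVE enabled; the function name is made up for the example, and the 16-byte (128-bit) alignment matches the legality rule added in AArch64LegalizerInfo.cpp below:

; Minimal sketch, assuming an AArch64 target with +sve.
; Both operations use a scalable type with a 128-bit minimum size and
; 128-bit alignment, so with this patch GlobalISel should no longer
; need to fall back to SelectionDAG for them.
define void @copy_nxv4i32(ptr %src, ptr %dst) {
  %v = load <vscale x 4 x i32>, ptr %src, align 16
  store <vscale x 4 x i32> %v, ptr %dst, align 16
  ret void
}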

10 files changed: +221 −25 lines

llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h

Lines changed: 4 additions & 4 deletions
@@ -652,17 +652,17 @@ bool GIMatchTableExecutor::executeMatchTable(
       MachineMemOperand *MMO =
           *(State.MIs[InsnID]->memoperands_begin() + MMOIdx);

-      unsigned Size = MRI.getType(MO.getReg()).getSizeInBits();
+      const auto Size = MRI.getType(MO.getReg()).getSizeInBits();
       if (MatcherOpcode == GIM_CheckMemorySizeEqualToLLT &&
-          MMO->getSizeInBits().getValue() != Size) {
+          MMO->getSizeInBits() != Size) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeLessThanLLT &&
-                 MMO->getSizeInBits().getValue() >= Size) {
+                 MMO->getSizeInBits().getValue() >= Size.getKnownMinValue()) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       } else if (MatcherOpcode == GIM_CheckMemorySizeGreaterThanLLT &&
-                 MMO->getSizeInBits().getValue() <= Size)
+                 MMO->getSizeInBits().getValue() <= Size.getKnownMinValue())
         if (handleReject() == RejectAndGiveUp)
           return false;

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 1 addition & 1 deletion
@@ -1080,7 +1080,7 @@ bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
   LLT Ty = MRI.getType(LdSt.getReg(0));
   LLT MemTy = LdSt.getMMO().getMemoryType();
   SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
-      {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
+      {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), AtomicOrdering::NotAtomic}});
   unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
   SmallVector<LLT> OpTys;
   if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 1 addition & 1 deletion
@@ -1413,7 +1413,7 @@ bool IRTranslator::translateLoad(const User &U, MachineIRBuilder &MIRBuilder) {

 bool IRTranslator::translateStore(const User &U, MachineIRBuilder &MIRBuilder) {
   const StoreInst &SI = cast<StoreInst>(U);
-  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()) == 0)
+  if (DL->getTypeStoreSize(SI.getValueOperand()->getType()).isZero())
     return true;

   ArrayRef<Register> Vals = getOrCreateVRegs(*SI.getValueOperand());

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 11 additions & 3 deletions
@@ -26375,12 +26375,20 @@ bool AArch64TargetLowering::shouldLocalize(
   return TargetLoweringBase::shouldLocalize(MI, TTI);
 }

+static bool isScalableTySupported(const unsigned Op) {
+  return Op == Instruction::Load || Op == Instruction::Store;
+}
+
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
-    return true;
+  const auto ScalableTySupported = isScalableTySupported(Inst.getOpcode());
+
+  // Fallback for scalable vectors
+  if (Inst.getType()->isScalableTy() && !ScalableTySupported) {
+    return true;
+  }

   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+    if (Inst.getOperand(i)->getType()->isScalableTy() && !ScalableTySupported)
       return true;

   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {

llvm/lib/Target/AArch64/AArch64RegisterBanks.td

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;

 /// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;

 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;

llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp

Lines changed: 53 additions & 6 deletions
@@ -901,6 +901,27 @@ static unsigned selectLoadStoreUIOp(unsigned GenericOpc, unsigned RegBankID,
   return GenericOpc;
 }

+/// Select the AArch64 opcode for the G_LOAD or G_STORE operation for scalable
+/// vectors.
+/// \p ElementSize size of the element of the scalable vector
+static unsigned selectLoadStoreSVEOp(const unsigned GenericOpc,
+                                     const unsigned ElementSize) {
+  const bool isStore = GenericOpc == TargetOpcode::G_STORE;
+
+  switch (ElementSize) {
+  case 8:
+    return isStore ? AArch64::ST1B : AArch64::LD1B;
+  case 16:
+    return isStore ? AArch64::ST1H : AArch64::LD1H;
+  case 32:
+    return isStore ? AArch64::ST1W : AArch64::LD1W;
+  case 64:
+    return isStore ? AArch64::ST1D : AArch64::LD1D;
+  }
+
+  return GenericOpc;
+}
+
 /// Helper function for selectCopy. Inserts a subregister copy from \p SrcReg
 /// to \p *To.
 ///
@@ -2853,8 +2874,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       return false;
     }

-    uint64_t MemSizeInBytes = LdSt.getMemSize().getValue();
-    unsigned MemSizeInBits = LdSt.getMemSizeInBits().getValue();
+    uint64_t MemSizeInBytes = LdSt.getMemSize().getValue().getKnownMinValue();
+    unsigned MemSizeInBits = LdSt.getMemSizeInBits().getValue().getKnownMinValue();
     AtomicOrdering Order = LdSt.getMMO().getSuccessOrdering();

     // Need special instructions for atomics that affect ordering.
@@ -2906,9 +2927,23 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     const LLT ValTy = MRI.getType(ValReg);
     const RegisterBank &RB = *RBI.getRegBank(ValReg, MRI, TRI);

+#ifndef NDEBUG
+    if (ValTy.isScalableVector()) {
+      assert(STI.hasSVE()
+             && "Load/Store register operand is scalable vector "
+                "while SVE is not supported by the target");
+      // assert(RB.getID() == AArch64::SVRRegBankID
+      //        && "Load/Store register operand is scalable vector "
+      //           "while its register bank is not SVR");
+    }
+#endif
+
     // The code below doesn't support truncating stores, so we need to split it
     // again.
-    if (isa<GStore>(LdSt) && ValTy.getSizeInBits() > MemSizeInBits) {
+    // Truncate only if type is not scalable vector
+    const bool NeedTrunc = !ValTy.isScalableVector()
+                           && ValTy.getSizeInBits().getFixedValue() > MemSizeInBits;
+    if (isa<GStore>(LdSt) && NeedTrunc) {
       unsigned SubReg;
       LLT MemTy = LdSt.getMMO().getMemoryType();
       auto *RC = getRegClassForTypeOnBank(MemTy, RB);
@@ -2921,7 +2956,7 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
                       .getReg(0);
       RBI.constrainGenericRegister(Copy, *RC, MRI);
       LdSt.getOperand(0).setReg(Copy);
-    } else if (isa<GLoad>(LdSt) && ValTy.getSizeInBits() > MemSizeInBits) {
+    } else if (isa<GLoad>(LdSt) && NeedTrunc) {
       // If this is an any-extending load from the FPR bank, split it into a regular
       // load + extend.
       if (RB.getID() == AArch64::FPRRegBankID) {
@@ -2951,10 +2986,19 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
     // instruction with an updated opcode, or a new instruction.
     auto SelectLoadStoreAddressingMode = [&]() -> MachineInstr * {
       bool IsStore = isa<GStore>(I);
-      const unsigned NewOpc =
-          selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
+      unsigned NewOpc;
+      if (ValTy.isScalableVector()) {
+        NewOpc = selectLoadStoreSVEOp(I.getOpcode(), ValTy.getElementType().getSizeInBits());
+      } else {
+        NewOpc = selectLoadStoreUIOp(I.getOpcode(), RB.getID(), MemSizeInBits);
+      }
       if (NewOpc == I.getOpcode())
         return nullptr;
+
+      if (ValTy.isScalableVector()) {
+        // Add the predicate register operand
+        I.addOperand(MachineOperand::CreatePredicate(true));
+      }
       // Check if we can fold anything into the addressing mode.
       auto AddrModeFns =
           selectAddrModeIndexed(I.getOperand(1), MemSizeInBytes);
@@ -2970,6 +3014,9 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
       Register CurValReg = I.getOperand(0).getReg();
      IsStore ? NewInst.addUse(CurValReg) : NewInst.addDef(CurValReg);
       NewInst.cloneMemRefs(I);
+      if (ValTy.isScalableVector()) {
+        NewInst.add(I.getOperand(1)); // Copy predicate register
+      }
       for (auto &Fn : *AddrModeFns)
         Fn(NewInst);
       I.eraseFromParent();

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 90 additions & 3 deletions
@@ -61,6 +61,79 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);

+  // Scalable vector sizes range from 128 to 2048
+  // Note that subtargets may not support the full range.
+  // See [ScalableVecTypes] below.
+  const LLT nxv16s8 = LLT::scalable_vector(16, s8);
+  const LLT nxv32s8 = LLT::scalable_vector(32, s8);
+  const LLT nxv64s8 = LLT::scalable_vector(64, s8);
+  const LLT nxv128s8 = LLT::scalable_vector(128, s8);
+  const LLT nxv256s8 = LLT::scalable_vector(256, s8);
+
+  const LLT nxv8s16 = LLT::scalable_vector(8, s16);
+  const LLT nxv16s16 = LLT::scalable_vector(16, s16);
+  const LLT nxv32s16 = LLT::scalable_vector(32, s16);
+  const LLT nxv64s16 = LLT::scalable_vector(64, s16);
+  const LLT nxv128s16 = LLT::scalable_vector(128, s16);
+
+  const LLT nxv4s32 = LLT::scalable_vector(4, s32);
+  const LLT nxv8s32 = LLT::scalable_vector(8, s32);
+  const LLT nxv16s32 = LLT::scalable_vector(16, s32);
+  const LLT nxv32s32 = LLT::scalable_vector(32, s32);
+  const LLT nxv64s32 = LLT::scalable_vector(64, s32);
+
+  const LLT nxv2s64 = LLT::scalable_vector(2, s64);
+  const LLT nxv4s64 = LLT::scalable_vector(4, s64);
+  const LLT nxv8s64 = LLT::scalable_vector(8, s64);
+  const LLT nxv16s64 = LLT::scalable_vector(16, s64);
+  const LLT nxv32s64 = LLT::scalable_vector(32, s64);
+
+  const LLT nxv2p0 = LLT::scalable_vector(2, p0);
+  const LLT nxv4p0 = LLT::scalable_vector(4, p0);
+  const LLT nxv8p0 = LLT::scalable_vector(8, p0);
+  const LLT nxv16p0 = LLT::scalable_vector(16, p0);
+  const LLT nxv32p0 = LLT::scalable_vector(32, p0);
+
+  const auto ScalableVec128 = {
+      nxv16s8, nxv8s16, nxv4s32, nxv2s64, nxv2p0,
+  };
+  const auto ScalableVec256 = {
+      nxv32s8, nxv16s16, nxv8s32, nxv4s64, nxv4p0,
+  };
+  const auto ScalableVec512 = {
+      nxv64s8, nxv32s16, nxv16s32, nxv8s64, nxv8p0,
+  };
+  const auto ScalableVec1024 = {
+      nxv128s8, nxv64s16, nxv32s32, nxv16s64, nxv16p0,
+  };
+  const auto ScalableVec2048 = {
+      nxv256s8, nxv128s16, nxv64s32, nxv32s64, nxv32p0,
+  };
+
+  /// Scalable vector types supported by the sub target.
+  /// Empty if SVE is not supported.
+  SmallVector<LLT> ScalableVecTypes;
+
+  if (ST.hasSVE()) {
+    // Add scalable vector types that are supported by the subtarget
+    const auto MinSize = ST.getMinSVEVectorSizeInBits();
+    auto MaxSize = ST.getMaxSVEVectorSizeInBits();
+    if (MaxSize == 0) {
+      // Unknown max size, assume the target supports all sizes.
+      MaxSize = 2048;
+    }
+    if (MinSize <= 128 && 128 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec128);
+    if (MinSize <= 256 && 256 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec256);
+    if (MinSize <= 512 && 512 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec512);
+    if (MinSize <= 1024 && 1024 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec1024);
+    if (MinSize <= 2048 && 2048 <= MaxSize)
+      ScalableVecTypes.append(ScalableVec2048);
+  }
+
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
                                                         v2s64, v2p0,
@@ -329,6 +402,18 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };

+  const auto IsSameScalableVecTy = [=](const LegalityQuery &Query) {
+    // Legal if loading a scalable vector type
+    // into a scalable vector register of the exactly same type
+    if (!Query.Types[0].isScalableVector() || Query.Types[1] != p0)
+      return false;
+    if (Query.MMODescrs[0].MemoryTy != Query.Types[0])
+      return false;
+    if (Query.MMODescrs[0].AlignInBits < 128)
+      return false;
+    return is_contained(ScalableVecTypes, Query.Types[0]);
+  };
+
   getActionDefinitionsBuilder(G_LOAD)
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
@@ -354,6 +439,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // These extends are also legal
       .legalForTypesWithMemDesc(
           {{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
+      .legalIf(IsSameScalableVecTy)
       .widenScalarToNextPow2(0, /* MinSize = */ 8)
       .clampMaxNumElements(0, s8, 16)
       .clampMaxNumElements(0, s16, 8)
@@ -398,7 +484,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
            {s64, p0, s64, 8}, {s64, p0, s32, 8}, // truncstorei32 from s64
            {p0, p0, s64, 8}, {s128, p0, s128, 8}, {v16s8, p0, s128, 8},
            {v8s8, p0, s64, 8}, {v4s16, p0, s64, 8}, {v8s16, p0, s128, 8},
-           {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
+           {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8},
+      })
+      .legalIf(IsSameScalableVecTy)
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&
@@ -440,8 +528,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                                  {p0, v4s32, v4s32, 8},
                                  {p0, v2s64, v2s64, 8},
                                  {p0, v2p0, v2p0, 8},
-                                 {p0, s128, s128, 8},
-                                 })
+                                 {p0, s128, s128, 8}})
       .unsupported();

   auto IndexedLoadBasicPred = [=](const LegalityQuery &Query) {

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp

Lines changed: 3 additions & 3 deletions
@@ -309,7 +309,7 @@ bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
   if (!Store.isSimple())
     return false;
   LLT ValTy = MRI.getType(Store.getValueReg());
-  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
+  if (!ValTy.isVector() || ValTy.getSizeInBits().getKnownMinValue() != 128)
     return false;
   if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
     return false; // Don't split truncating stores.
@@ -657,8 +657,8 @@ bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
       Register PtrBaseReg;
       APInt Offset;
       LLT StoredValTy = MRI.getType(St->getValueReg());
-      unsigned ValSize = StoredValTy.getSizeInBits();
-      if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
+      const auto ValSize = StoredValTy.getSizeInBits();
+      if (ValSize.getKnownMinValue() < 32 || St->getMMO().getSizeInBits() != ValSize)
        continue;

       Register PtrReg = St->getPointerReg();

llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp

Lines changed: 7 additions & 3 deletions
@@ -257,6 +257,7 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -740,11 +741,14 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     LLT Ty = MRI.getType(MO.getReg());
     if (!Ty.isValid())
       continue;
-    OpSize[Idx] = Ty.getSizeInBits();
+    OpSize[Idx] = Ty.getSizeInBits().getKnownMinValue();

-    // As a top-level guess, vectors go in FPRs, scalars and pointers in GPRs.
+    // As a top-level guess, scalable vectors go in SVRs, non-scalable
+    // vectors go in FPRs, scalars and pointers in GPRs.
     // For floating-point instructions, scalars go in FPRs.
-    if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
+    if (Ty.isScalableVector())
+      OpRegBankIdx[Idx] = PMI_FirstFPR;
+    else if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
         Ty.getSizeInBits() > 64)
       OpRegBankIdx[Idx] = PMI_FirstFPR;
     else
