Skip to content

Commit 19999f2

Browse files
AMDGPU/GlobalISel: RegBankLegalize rules for load
Add IDs for bit widths that cover multiple LLTs (B32, B64, etc.). Add a "Predicate" wrapper class for bool predicate functions, used to write pretty rules; predicates can be combined using &&, || and !. Add lowering for splitting and widening loads. Write rules for loads so that existing MIR tests from the old regbankselect do not change.
1 parent 3f085e7 commit 19999f2

File tree

6 files changed

+941
-66
lines changed

6 files changed

+941
-66
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 294 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
3737
return true;
3838
}
3939

40+
// Replaces the load MI (which must carry exactly one memory operand) with one
// load per entry of LLTBreakdown, laid out back to back starting at the
// original pointer, and recombines the loaded parts into MI's destination
// register. MI is erased afterwards.
//
// MI          - load to split; operand 0 is the destination, operand 1 the
//               base pointer.
// LLTBreakdown - type of each partial load, in increasing address order.
// MergeTy     - if valid, every part whose type differs from MergeTy is first
//               unmerged into MergeTy-typed pieces before the final merge; if
//               invalid, the parts are merged directly.
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
  MachineFunction &MF = B.getMF();
  assert(MI.getNumMemOperands() == 1);
  MachineMemOperand &OrigMMO = **MI.memoperands_begin();

  Register Dst = MI.getOperand(0).getReg();
  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);

  Register Base = MI.getOperand(1).getReg();
  LLT PtrTy = MRI.getType(Base);
  const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
  // Offsets are materialized as a scalar as wide as the pointer.
  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());

  SmallVector<Register, 4> PartRegs;
  unsigned Offset = 0;
  for (LLT PartTy : LLTBreakdown) {
    // At offset 0 the base pointer is used as is; otherwise emit
    // Base + constant offset on the pointer's register bank.
    Register Addr = Base;
    if (Offset != 0) {
      Addr = MRI.createVirtualRegister({PtrRB, PtrTy});
      Register OffsetReg = MRI.createVirtualRegister({PtrRB, OffsetTy});
      B.buildConstant(OffsetReg, Offset);
      B.buildPtrAdd(Addr, Base, OffsetReg);
    }

    // Memory operand for this part is derived from the original one, shifted
    // by the running byte offset and retyped to PartTy.
    MachineMemOperand *PartMMO =
        MF.getMachineMemOperand(&OrigMMO, Offset, PartTy);
    Register Part = MRI.createVirtualRegister({DstRB, PartTy});
    B.buildLoad(Part, Addr, *PartMMO);
    PartRegs.push_back(Part);

    Offset += PartTy.getSizeInBytes();
  }

  // No MergeTy: the loaded parts are concatenated/merged directly into Dst.
  if (!MergeTy.isValid()) {
    B.buildMergeLikeInstr(Dst, PartRegs);
    MI.eraseFromParent();
    return;
  }

  // Parts differ in size: break every part that is not already MergeTy into
  // MergeTy pieces, then merge all pieces into Dst.
  SmallVector<Register, 4> Pieces;
  for (Register PartReg : PartRegs) {
    if (MRI.getType(PartReg) == MergeTy) {
      Pieces.push_back(PartReg);
      continue;
    }

    auto Unmerge = B.buildUnmerge(MergeTy, PartReg);
    // All operands except the last one are the unmerge results.
    for (unsigned Idx = 0; Idx < Unmerge->getNumOperands() - 1; ++Idx) {
      Register Piece = Unmerge->getOperand(Idx).getReg();
      MRI.setRegBank(Piece, *DstRB);
      Pieces.push_back(Piece);
    }
  }
  B.buildMergeLikeInstr(Dst, Pieces);
  MI.eraseFromParent();
}
95+
96+
// Replaces the load MI (which must carry exactly one memory operand) with a
// single load of the wider type WideTy from the same pointer, then narrows the
// result back into MI's destination register. MI is erased afterwards.
//
// MI      - load to widen; operand 0 is the destination, operand 1 the pointer.
// WideTy  - type of the replacement load.
// MergeTy - piece type used on the non-scalar path: the wide value is unmerged
//           into MergeTy pieces and only the leading pieces that cover the
//           destination size are merged into the destination. Not used when
//           WideTy is scalar (a truncate is emitted instead).
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
                                      LLT MergeTy) {
  MachineFunction &MF = B.getMF();
  assert(MI.getNumMemOperands() == 1);
  MachineMemOperand &OrigMMO = **MI.memoperands_begin();

  Register Dst = MI.getOperand(0).getReg();
  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
  Register Base = MI.getOperand(1).getReg();

  // Same address as the original load (offset 0), memory type widened.
  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&OrigMMO, 0, WideTy);
  Register WideLoad = MRI.createVirtualRegister({DstRB, WideTy});
  B.buildLoad(WideLoad, Base, *WideMMO);

  // Scalar case: drop the extra high bits with a truncate.
  if (WideTy.isScalar()) {
    B.buildTrunc(Dst, WideLoad);
    MI.eraseFromParent();
    return;
  }

  // Number of MergeTy pieces needed to fill the destination.
  unsigned NumKept =
      MRI.getType(Dst).getSizeInBits() / MergeTy.getSizeInBits();
  auto Unmerge = B.buildUnmerge(MergeTy, WideLoad);

  SmallVector<Register, 4> Kept;
  // All operands except the last one are the unmerge results. Every result
  // gets the destination bank; only the first NumKept feed the final merge.
  for (unsigned Idx = 0; Idx < Unmerge->getNumOperands() - 1; ++Idx) {
    Register Piece = Unmerge->getOperand(Idx).getReg();
    MRI.setRegBank(Piece, *DstRB);
    if (Idx < NumKept)
      Kept.push_back(Piece);
  }
  B.buildMergeLikeInstr(Dst, Kept);
  MI.eraseFromParent();
}
130+
40131
void RegBankLegalizeHelper::lower(MachineInstr &MI,
41132
const RegBankLLTMapping &Mapping,
42133
SmallSet<Register, 4> &WaterfallSGPRs) {
@@ -119,6 +210,50 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
119210
MI.eraseFromParent();
120211
break;
121212
}
213+
case SplitLoad: {
214+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
215+
unsigned Size = DstTy.getSizeInBits();
216+
// Even split to 128-bit loads
217+
if (Size > 128) {
218+
LLT B128;
219+
if (DstTy.isVector()) {
220+
LLT EltTy = DstTy.getElementType();
221+
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
222+
} else {
223+
B128 = LLT::scalar(128);
224+
}
225+
if (Size / 128 == 2)
226+
splitLoad(MI, {B128, B128});
227+
if (Size / 128 == 4)
228+
splitLoad(MI, {B128, B128, B128, B128});
229+
}
230+
// 64 and 32 bit load
231+
else if (DstTy == S96)
232+
splitLoad(MI, {S64, S32}, S32);
233+
else if (DstTy == V3S32)
234+
splitLoad(MI, {V2S32, S32}, S32);
235+
else if (DstTy == V6S16)
236+
splitLoad(MI, {V4S16, V2S16}, V2S16);
237+
else {
238+
MI.dump();
239+
llvm_unreachable("SplitLoad type not supported");
240+
}
241+
break;
242+
}
243+
case WidenLoad: {
244+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
245+
if (DstTy == S96)
246+
widenLoad(MI, S128);
247+
else if (DstTy == V3S32)
248+
widenLoad(MI, V4S32, S32);
249+
else if (DstTy == V6S16)
250+
widenLoad(MI, V8S16, V2S16);
251+
else {
252+
MI.dump();
253+
llvm_unreachable("WidenLoad type not supported");
254+
}
255+
break;
256+
}
122257
}
123258

124259
// TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +277,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
142277
case Sgpr64:
143278
case Vgpr64:
144279
return LLT::scalar(64);
145-
280+
case SgprP1:
281+
case VgprP1:
282+
return LLT::pointer(1, 64);
283+
case SgprP3:
284+
case VgprP3:
285+
return LLT::pointer(3, 32);
286+
case SgprP4:
287+
case VgprP4:
288+
return LLT::pointer(4, 64);
289+
case SgprP5:
290+
case VgprP5:
291+
return LLT::pointer(5, 32);
146292
case SgprV4S32:
147293
case VgprV4S32:
148294
case UniInVgprV4S32:
149295
return LLT::fixed_vector(4, 32);
150-
case VgprP1:
151-
return LLT::pointer(1, 64);
296+
default:
297+
return LLT();
298+
}
299+
}
300+
301+
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
302+
switch (ID) {
303+
case SgprB32:
304+
case VgprB32:
305+
case UniInVgprB32:
306+
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
307+
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
308+
Ty == LLT::pointer(6, 32))
309+
return Ty;
310+
return LLT();
311+
case SgprB64:
312+
case VgprB64:
313+
case UniInVgprB64:
314+
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
315+
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
316+
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
317+
return Ty;
318+
return LLT();
319+
case SgprB96:
320+
case VgprB96:
321+
case UniInVgprB96:
322+
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
323+
Ty == LLT::fixed_vector(6, 16))
324+
return Ty;
325+
return LLT();
326+
case SgprB128:
327+
case VgprB128:
328+
case UniInVgprB128:
329+
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
330+
Ty == LLT::fixed_vector(2, 64))
331+
return Ty;
332+
return LLT();
333+
case SgprB256:
334+
case VgprB256:
335+
case UniInVgprB256:
336+
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
337+
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
338+
return Ty;
339+
return LLT();
340+
case SgprB512:
341+
case VgprB512:
342+
case UniInVgprB512:
343+
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
344+
Ty == LLT::fixed_vector(8, 64))
345+
return Ty;
346+
return LLT();
152347
default:
153348
return LLT();
154349
}
@@ -163,10 +358,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
163358
case Sgpr16:
164359
case Sgpr32:
165360
case Sgpr64:
361+
case SgprP1:
362+
case SgprP3:
363+
case SgprP4:
364+
case SgprP5:
166365
case SgprV4S32:
366+
case SgprB32:
367+
case SgprB64:
368+
case SgprB96:
369+
case SgprB128:
370+
case SgprB256:
371+
case SgprB512:
167372
case UniInVcc:
168373
case UniInVgprS32:
169374
case UniInVgprV4S32:
375+
case UniInVgprB32:
376+
case UniInVgprB64:
377+
case UniInVgprB96:
378+
case UniInVgprB128:
379+
case UniInVgprB256:
380+
case UniInVgprB512:
170381
case Sgpr32Trunc:
171382
case Sgpr32AExt:
172383
case Sgpr32AExtBoolInReg:
@@ -176,7 +387,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
176387
case Vgpr32:
177388
case Vgpr64:
178389
case VgprP1:
390+
case VgprP3:
391+
case VgprP4:
392+
case VgprP5:
179393
case VgprV4S32:
394+
case VgprB32:
395+
case VgprB64:
396+
case VgprB96:
397+
case VgprB128:
398+
case VgprB256:
399+
case VgprB512:
180400
return VgprRB;
181401

182402
default:
@@ -202,17 +422,42 @@ void RegBankLegalizeHelper::applyMappingDst(
202422
case Sgpr16:
203423
case Sgpr32:
204424
case Sgpr64:
425+
case SgprP1:
426+
case SgprP3:
427+
case SgprP4:
428+
case SgprP5:
205429
case SgprV4S32:
206430
case Vgpr32:
207431
case Vgpr64:
208432
case VgprP1:
433+
case VgprP3:
434+
case VgprP4:
435+
case VgprP5:
209436
case VgprV4S32: {
210437
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
211438
assert(RB == getRBFromID(MethodIDs[OpIdx]));
212439
break;
213440
}
214441

215-
// uniform in vcc/vgpr: scalars and vectors
442+
// sgpr and vgpr B-types
443+
case SgprB32:
444+
case SgprB64:
445+
case SgprB96:
446+
case SgprB128:
447+
case SgprB256:
448+
case SgprB512:
449+
case VgprB32:
450+
case VgprB64:
451+
case VgprB96:
452+
case VgprB128:
453+
case VgprB256:
454+
case VgprB512: {
455+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
456+
assert(RB == getRBFromID(MethodIDs[OpIdx]));
457+
break;
458+
}
459+
460+
// uniform in vcc/vgpr: scalars, vectors and B-types
216461
case UniInVcc: {
217462
assert(Ty == S1);
218463
assert(RB == SgprRB);
@@ -229,6 +474,17 @@ void RegBankLegalizeHelper::applyMappingDst(
229474
AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
230475
break;
231476
}
477+
case UniInVgprB32:
478+
case UniInVgprB64:
479+
case UniInVgprB96:
480+
case UniInVgprB128:
481+
case UniInVgprB256:
482+
case UniInVgprB512: {
483+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
484+
assert(RB == SgprRB);
485+
AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
486+
break;
487+
}
232488

233489
// sgpr trunc
234490
case Sgpr32Trunc: {
@@ -279,16 +535,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
279535
case Sgpr16:
280536
case Sgpr32:
281537
case Sgpr64:
538+
case SgprP1:
539+
case SgprP3:
540+
case SgprP4:
541+
case SgprP5:
282542
case SgprV4S32: {
283543
assert(Ty == getTyFromID(MethodIDs[i]));
284544
assert(RB == getRBFromID(MethodIDs[i]));
285545
break;
286546
}
547+
// sgpr B-types
548+
case SgprB32:
549+
case SgprB64:
550+
case SgprB96:
551+
case SgprB128:
552+
case SgprB256:
553+
case SgprB512: {
554+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
555+
assert(RB == getRBFromID(MethodIDs[i]));
556+
break;
557+
}
287558

288559
// vgpr scalars, pointers and vectors
289560
case Vgpr32:
290561
case Vgpr64:
291562
case VgprP1:
563+
case VgprP3:
564+
case VgprP4:
565+
case VgprP5:
292566
case VgprV4S32: {
293567
assert(Ty == getTyFromID(MethodIDs[i]));
294568
if (RB != VgprRB) {
@@ -298,6 +572,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
298572
}
299573
break;
300574
}
575+
// vgpr B-types
576+
case VgprB32:
577+
case VgprB64:
578+
case VgprB96:
579+
case VgprB128:
580+
case VgprB256:
581+
case VgprB512: {
582+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
583+
if (RB != VgprRB) {
584+
auto CopyToVgpr =
585+
B.buildCopy(createVgpr(getBTyFromID(MethodIDs[i], Ty)), Reg);
586+
Op.setReg(CopyToVgpr.getReg(0));
587+
}
588+
break;
589+
}
301590

302591
// sgpr and vgpr scalars with extend
303592
case Sgpr32AExt: {
@@ -372,7 +661,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
372661
// We accept all types that can fit in some register class.
373662
// Uniform G_PHIs have all sgpr registers.
374663
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
375-
if (Ty == LLT::scalar(32)) {
664+
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
376665
return;
377666
}
378667

0 commit comments

Comments
 (0)