Skip to content

Commit 4675f79

Browse files
AMDGPU/GlobalISel: RegBankLegalize rules for load
Add IDs for bit width that cover multiple LLTs: B32 B64 etc. "Predicate" wrapper class for bool predicate functions used to write pretty rules. Predicates can be combined using &&, || and !. Lowering for splitting and widening loads. Write rules for loads to not change existing mir tests from old regbankselect.
1 parent f7ee75a commit 4675f79

File tree

6 files changed

+929
-66
lines changed

6 files changed

+929
-66
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 282 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
3636
lower(MI, Mapping, WaterfallSgprs);
3737
}
3838

39+
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
40+
ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
41+
MachineFunction &MF = B.getMF();
42+
assert(MI.getNumMemOperands() == 1);
43+
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
44+
Register Dst = MI.getOperand(0).getReg();
45+
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
46+
Register Base = MI.getOperand(1).getReg();
47+
LLT PtrTy = MRI.getType(Base);
48+
const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
49+
LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
50+
SmallVector<Register, 4> LoadPartRegs;
51+
52+
unsigned ByteOffset = 0;
53+
for (LLT PartTy : LLTBreakdown) {
54+
Register BasePlusOffset;
55+
if (ByteOffset == 0) {
56+
BasePlusOffset = Base;
57+
} else {
58+
auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
59+
BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
60+
}
61+
auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
62+
auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
63+
LoadPartRegs.push_back(LoadPart.getReg(0));
64+
ByteOffset += PartTy.getSizeInBytes();
65+
}
66+
67+
if (!MergeTy.isValid()) {
68+
// Loads are of same size, concat or merge them together.
69+
B.buildMergeLikeInstr(Dst, LoadPartRegs);
70+
} else {
71+
// Loads are not all of same size, need to unmerge them to smaller pieces
72+
// of MergeTy type, then merge pieces to Dst.
73+
SmallVector<Register, 4> MergeTyParts;
74+
for (Register Reg : LoadPartRegs) {
75+
if (MRI.getType(Reg) == MergeTy) {
76+
MergeTyParts.push_back(Reg);
77+
} else {
78+
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
79+
for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
80+
MergeTyParts.push_back(Unmerge.getReg(i));
81+
}
82+
}
83+
B.buildMergeLikeInstr(Dst, MergeTyParts);
84+
}
85+
MI.eraseFromParent();
86+
}
87+
88+
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
89+
LLT MergeTy) {
90+
MachineFunction &MF = B.getMF();
91+
assert(MI.getNumMemOperands() == 1);
92+
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
93+
Register Dst = MI.getOperand(0).getReg();
94+
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
95+
Register Base = MI.getOperand(1).getReg();
96+
97+
MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
98+
auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
99+
100+
if (WideTy.isScalar()) {
101+
B.buildTrunc(Dst, WideLoad);
102+
} else {
103+
SmallVector<Register, 4> MergeTyParts;
104+
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
105+
106+
LLT DstTy = MRI.getType(Dst);
107+
unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
108+
for (unsigned i = 0; i < NumElts; ++i) {
109+
MergeTyParts.push_back(Unmerge.getReg(i));
110+
}
111+
B.buildMergeLikeInstr(Dst, MergeTyParts);
112+
}
113+
MI.eraseFromParent();
114+
}
115+
39116
void RegBankLegalizeHelper::lower(MachineInstr &MI,
40117
const RegBankLLTMapping &Mapping,
41118
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -114,6 +191,50 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
114191
MI.eraseFromParent();
115192
break;
116193
}
194+
case SplitLoad: {
195+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
196+
unsigned Size = DstTy.getSizeInBits();
197+
// Even split to 128-bit loads
198+
if (Size > 128) {
199+
LLT B128;
200+
if (DstTy.isVector()) {
201+
LLT EltTy = DstTy.getElementType();
202+
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
203+
} else {
204+
B128 = LLT::scalar(128);
205+
}
206+
if (Size / 128 == 2)
207+
splitLoad(MI, {B128, B128});
208+
if (Size / 128 == 4)
209+
splitLoad(MI, {B128, B128, B128, B128});
210+
}
211+
// 64 and 32 bit load
212+
else if (DstTy == S96)
213+
splitLoad(MI, {S64, S32}, S32);
214+
else if (DstTy == V3S32)
215+
splitLoad(MI, {V2S32, S32}, S32);
216+
else if (DstTy == V6S16)
217+
splitLoad(MI, {V4S16, V2S16}, V2S16);
218+
else {
219+
MI.dump();
220+
llvm_unreachable("SplitLoad type not supported");
221+
}
222+
break;
223+
}
224+
case WidenLoad: {
225+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
226+
if (DstTy == S96)
227+
widenLoad(MI, S128);
228+
else if (DstTy == V3S32)
229+
widenLoad(MI, V4S32, S32);
230+
else if (DstTy == V6S16)
231+
widenLoad(MI, V8S16, V2S16);
232+
else {
233+
MI.dump();
234+
llvm_unreachable("WidenLoad type not supported");
235+
}
236+
break;
237+
}
117238
}
118239

119240
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -137,13 +258,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
137258
case Sgpr64:
138259
case Vgpr64:
139260
return LLT::scalar(64);
140-
261+
case SgprP1:
262+
case VgprP1:
263+
return LLT::pointer(1, 64);
264+
case SgprP3:
265+
case VgprP3:
266+
return LLT::pointer(3, 32);
267+
case SgprP4:
268+
case VgprP4:
269+
return LLT::pointer(4, 64);
270+
case SgprP5:
271+
case VgprP5:
272+
return LLT::pointer(5, 32);
141273
case SgprV4S32:
142274
case VgprV4S32:
143275
case UniInVgprV4S32:
144276
return LLT::fixed_vector(4, 32);
145-
case VgprP1:
146-
return LLT::pointer(1, 64);
277+
default:
278+
return LLT();
279+
}
280+
}
281+
282+
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
283+
switch (ID) {
284+
case SgprB32:
285+
case VgprB32:
286+
case UniInVgprB32:
287+
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
288+
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
289+
Ty == LLT::pointer(6, 32))
290+
return Ty;
291+
return LLT();
292+
case SgprB64:
293+
case VgprB64:
294+
case UniInVgprB64:
295+
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
296+
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
297+
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
298+
return Ty;
299+
return LLT();
300+
case SgprB96:
301+
case VgprB96:
302+
case UniInVgprB96:
303+
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
304+
Ty == LLT::fixed_vector(6, 16))
305+
return Ty;
306+
return LLT();
307+
case SgprB128:
308+
case VgprB128:
309+
case UniInVgprB128:
310+
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
311+
Ty == LLT::fixed_vector(2, 64))
312+
return Ty;
313+
return LLT();
314+
case SgprB256:
315+
case VgprB256:
316+
case UniInVgprB256:
317+
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
318+
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
319+
return Ty;
320+
return LLT();
321+
case SgprB512:
322+
case VgprB512:
323+
case UniInVgprB512:
324+
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
325+
Ty == LLT::fixed_vector(8, 64))
326+
return Ty;
327+
return LLT();
147328
default:
148329
return LLT();
149330
}
@@ -158,10 +339,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
158339
case Sgpr16:
159340
case Sgpr32:
160341
case Sgpr64:
342+
case SgprP1:
343+
case SgprP3:
344+
case SgprP4:
345+
case SgprP5:
161346
case SgprV4S32:
347+
case SgprB32:
348+
case SgprB64:
349+
case SgprB96:
350+
case SgprB128:
351+
case SgprB256:
352+
case SgprB512:
162353
case UniInVcc:
163354
case UniInVgprS32:
164355
case UniInVgprV4S32:
356+
case UniInVgprB32:
357+
case UniInVgprB64:
358+
case UniInVgprB96:
359+
case UniInVgprB128:
360+
case UniInVgprB256:
361+
case UniInVgprB512:
165362
case Sgpr32Trunc:
166363
case Sgpr32AExt:
167364
case Sgpr32AExtBoolInReg:
@@ -171,7 +368,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
171368
case Vgpr32:
172369
case Vgpr64:
173370
case VgprP1:
371+
case VgprP3:
372+
case VgprP4:
373+
case VgprP5:
174374
case VgprV4S32:
375+
case VgprB32:
376+
case VgprB64:
377+
case VgprB96:
378+
case VgprB128:
379+
case VgprB256:
380+
case VgprB512:
175381
return VgprRB;
176382

177383
default:
@@ -197,17 +403,42 @@ void RegBankLegalizeHelper::applyMappingDst(
197403
case Sgpr16:
198404
case Sgpr32:
199405
case Sgpr64:
406+
case SgprP1:
407+
case SgprP3:
408+
case SgprP4:
409+
case SgprP5:
200410
case SgprV4S32:
201411
case Vgpr32:
202412
case Vgpr64:
203413
case VgprP1:
414+
case VgprP3:
415+
case VgprP4:
416+
case VgprP5:
204417
case VgprV4S32: {
205418
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
206419
assert(RB == getRBFromID(MethodIDs[OpIdx]));
207420
break;
208421
}
209422

210-
// uniform in vcc/vgpr: scalars and vectors
423+
// sgpr and vgpr B-types
424+
case SgprB32:
425+
case SgprB64:
426+
case SgprB96:
427+
case SgprB128:
428+
case SgprB256:
429+
case SgprB512:
430+
case VgprB32:
431+
case VgprB64:
432+
case VgprB96:
433+
case VgprB128:
434+
case VgprB256:
435+
case VgprB512: {
436+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
437+
assert(RB == getRBFromID(MethodIDs[OpIdx]));
438+
break;
439+
}
440+
441+
// uniform in vcc/vgpr: scalars, vectors and B-types
211442
case UniInVcc: {
212443
assert(Ty == S1);
213444
assert(RB == SgprRB);
@@ -227,6 +458,20 @@ void RegBankLegalizeHelper::applyMappingDst(
227458
buildReadAnyLane(B, Reg, NewVgprDst, RBI);
228459
break;
229460
}
461+
case UniInVgprB32:
462+
case UniInVgprB64:
463+
case UniInVgprB96:
464+
case UniInVgprB128:
465+
case UniInVgprB256:
466+
case UniInVgprB512: {
467+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
468+
assert(RB == SgprRB);
469+
470+
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
471+
Op.setReg(NewVgprDst);
472+
AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
473+
break;
474+
}
230475

231476
// sgpr trunc
232477
case Sgpr32Trunc: {
@@ -278,16 +523,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
278523
case Sgpr16:
279524
case Sgpr32:
280525
case Sgpr64:
526+
case SgprP1:
527+
case SgprP3:
528+
case SgprP4:
529+
case SgprP5:
281530
case SgprV4S32: {
282531
assert(Ty == getTyFromID(MethodIDs[i]));
283532
assert(RB == getRBFromID(MethodIDs[i]));
284533
break;
285534
}
535+
// sgpr B-types
536+
case SgprB32:
537+
case SgprB64:
538+
case SgprB96:
539+
case SgprB128:
540+
case SgprB256:
541+
case SgprB512: {
542+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
543+
assert(RB == getRBFromID(MethodIDs[i]));
544+
break;
545+
}
286546

287547
// vgpr scalars, pointers and vectors
288548
case Vgpr32:
289549
case Vgpr64:
290550
case VgprP1:
551+
case VgprP3:
552+
case VgprP4:
553+
case VgprP5:
291554
case VgprV4S32: {
292555
assert(Ty == getTyFromID(MethodIDs[i]));
293556
if (RB != VgprRB) {
@@ -296,6 +559,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
296559
}
297560
break;
298561
}
562+
// vgpr B-types
563+
case VgprB32:
564+
case VgprB64:
565+
case VgprB96:
566+
case VgprB128:
567+
case VgprB256:
568+
case VgprB512: {
569+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
570+
if (RB != VgprRB) {
571+
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
572+
Op.setReg(CopyToVgpr.getReg(0));
573+
}
574+
break;
575+
}
299576

300577
// sgpr and vgpr scalars with extend
301578
case Sgpr32AExt: {
@@ -368,7 +645,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
368645
// We accept all types that can fit in some register class.
369646
// Uniform G_PHIs have all sgpr registers.
370647
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
371-
if (Ty == LLT::scalar(32)) {
648+
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
372649
return;
373650
}
374651

0 commit comments

Comments
 (0)