Skip to content

Commit 59e70ef

Browse files
AMDGPU/GlobalISel: RegBankLegalize rules for load
Add IDs for bit width that cover multiple LLTs: B32 B64 etc. "Predicate" wrapper class for bool predicate functions used to write pretty rules. Predicates can be combined using &&, || and !. Lowering for splitting and widening loads. Write rules for loads to not change existing mir tests from old regbankselect.
1 parent da35db7 commit 59e70ef

File tree

6 files changed

+927
-65
lines changed

6 files changed

+927
-65
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 280 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
3838
lower(MI, Mapping, WaterfallSgprs);
3939
}
4040

41+
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
42+
ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
43+
MachineFunction &MF = B.getMF();
44+
assert(MI.getNumMemOperands() == 1);
45+
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
46+
Register Dst = MI.getOperand(0).getReg();
47+
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
48+
Register Base = MI.getOperand(1).getReg();
49+
LLT PtrTy = MRI.getType(Base);
50+
const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
51+
LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
52+
SmallVector<Register, 4> LoadPartRegs;
53+
54+
unsigned ByteOffset = 0;
55+
for (LLT PartTy : LLTBreakdown) {
56+
Register BasePlusOffset;
57+
if (ByteOffset == 0) {
58+
BasePlusOffset = Base;
59+
} else {
60+
auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
61+
BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
62+
}
63+
auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
64+
auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
65+
LoadPartRegs.push_back(LoadPart.getReg(0));
66+
ByteOffset += PartTy.getSizeInBytes();
67+
}
68+
69+
if (!MergeTy.isValid()) {
70+
// Loads are of same size, concat or merge them together.
71+
B.buildMergeLikeInstr(Dst, LoadPartRegs);
72+
} else {
73+
// Loads are not all of same size, need to unmerge them to smaller pieces
74+
// of MergeTy type, then merge pieces to Dst.
75+
SmallVector<Register, 4> MergeTyParts;
76+
for (Register Reg : LoadPartRegs) {
77+
if (MRI.getType(Reg) == MergeTy) {
78+
MergeTyParts.push_back(Reg);
79+
} else {
80+
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
81+
for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
82+
MergeTyParts.push_back(Unmerge.getReg(i));
83+
}
84+
}
85+
B.buildMergeLikeInstr(Dst, MergeTyParts);
86+
}
87+
MI.eraseFromParent();
88+
}
89+
90+
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
91+
LLT MergeTy) {
92+
MachineFunction &MF = B.getMF();
93+
assert(MI.getNumMemOperands() == 1);
94+
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
95+
Register Dst = MI.getOperand(0).getReg();
96+
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
97+
Register Base = MI.getOperand(1).getReg();
98+
99+
MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
100+
auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
101+
102+
if (WideTy.isScalar()) {
103+
B.buildTrunc(Dst, WideLoad);
104+
} else {
105+
SmallVector<Register, 4> MergeTyParts;
106+
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
107+
108+
LLT DstTy = MRI.getType(Dst);
109+
unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
110+
for (unsigned i = 0; i < NumElts; ++i) {
111+
MergeTyParts.push_back(Unmerge.getReg(i));
112+
}
113+
B.buildMergeLikeInstr(Dst, MergeTyParts);
114+
}
115+
MI.eraseFromParent();
116+
}
117+
41118
void RegBankLegalizeHelper::lower(MachineInstr &MI,
42119
const RegBankLLTMapping &Mapping,
43120
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -116,6 +193,50 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
116193
MI.eraseFromParent();
117194
break;
118195
}
196+
case SplitLoad: {
197+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
198+
unsigned Size = DstTy.getSizeInBits();
199+
// Even split to 128-bit loads
200+
if (Size > 128) {
201+
LLT B128;
202+
if (DstTy.isVector()) {
203+
LLT EltTy = DstTy.getElementType();
204+
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
205+
} else {
206+
B128 = LLT::scalar(128);
207+
}
208+
if (Size / 128 == 2)
209+
splitLoad(MI, {B128, B128});
210+
if (Size / 128 == 4)
211+
splitLoad(MI, {B128, B128, B128, B128});
212+
}
213+
// 64 and 32 bit load
214+
else if (DstTy == S96)
215+
splitLoad(MI, {S64, S32}, S32);
216+
else if (DstTy == V3S32)
217+
splitLoad(MI, {V2S32, S32}, S32);
218+
else if (DstTy == V6S16)
219+
splitLoad(MI, {V4S16, V2S16}, V2S16);
220+
else {
221+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
222+
llvm_unreachable("SplitLoad type not supported for MI");
223+
}
224+
break;
225+
}
226+
case WidenLoad: {
227+
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
228+
if (DstTy == S96)
229+
widenLoad(MI, S128);
230+
else if (DstTy == V3S32)
231+
widenLoad(MI, V4S32, S32);
232+
else if (DstTy == V6S16)
233+
widenLoad(MI, V8S16, V2S16);
234+
else {
235+
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
236+
llvm_unreachable("WidenLoad type not supported for MI");
237+
}
238+
break;
239+
}
119240
}
120241

121242
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -139,12 +260,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
139260
case Sgpr64:
140261
case Vgpr64:
141262
return LLT::scalar(64);
263+
case SgprP1:
264+
case VgprP1:
265+
return LLT::pointer(1, 64);
266+
case SgprP3:
267+
case VgprP3:
268+
return LLT::pointer(3, 32);
269+
case SgprP4:
270+
case VgprP4:
271+
return LLT::pointer(4, 64);
272+
case SgprP5:
273+
case VgprP5:
274+
return LLT::pointer(5, 32);
142275
case SgprV4S32:
143276
case VgprV4S32:
144277
case UniInVgprV4S32:
145278
return LLT::fixed_vector(4, 32);
146-
case VgprP1:
147-
return LLT::pointer(1, 64);
279+
default:
280+
return LLT();
281+
}
282+
}
283+
284+
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
285+
switch (ID) {
286+
case SgprB32:
287+
case VgprB32:
288+
case UniInVgprB32:
289+
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
290+
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
291+
Ty == LLT::pointer(6, 32))
292+
return Ty;
293+
return LLT();
294+
case SgprB64:
295+
case VgprB64:
296+
case UniInVgprB64:
297+
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
298+
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
299+
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
300+
return Ty;
301+
return LLT();
302+
case SgprB96:
303+
case VgprB96:
304+
case UniInVgprB96:
305+
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
306+
Ty == LLT::fixed_vector(6, 16))
307+
return Ty;
308+
return LLT();
309+
case SgprB128:
310+
case VgprB128:
311+
case UniInVgprB128:
312+
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
313+
Ty == LLT::fixed_vector(2, 64))
314+
return Ty;
315+
return LLT();
316+
case SgprB256:
317+
case VgprB256:
318+
case UniInVgprB256:
319+
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
320+
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
321+
return Ty;
322+
return LLT();
323+
case SgprB512:
324+
case VgprB512:
325+
case UniInVgprB512:
326+
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
327+
Ty == LLT::fixed_vector(8, 64))
328+
return Ty;
329+
return LLT();
148330
default:
149331
return LLT();
150332
}
@@ -158,10 +340,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMapingApplyID ID) {
158340
case Sgpr16:
159341
case Sgpr32:
160342
case Sgpr64:
343+
case SgprP1:
344+
case SgprP3:
345+
case SgprP4:
346+
case SgprP5:
161347
case SgprV4S32:
348+
case SgprB32:
349+
case SgprB64:
350+
case SgprB96:
351+
case SgprB128:
352+
case SgprB256:
353+
case SgprB512:
162354
case UniInVcc:
163355
case UniInVgprS32:
164356
case UniInVgprV4S32:
357+
case UniInVgprB32:
358+
case UniInVgprB64:
359+
case UniInVgprB96:
360+
case UniInVgprB128:
361+
case UniInVgprB256:
362+
case UniInVgprB512:
165363
case Sgpr32Trunc:
166364
case Sgpr32AExt:
167365
case Sgpr32AExtBoolInReg:
@@ -170,7 +368,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMapingApplyID ID) {
170368
case Vgpr32:
171369
case Vgpr64:
172370
case VgprP1:
371+
case VgprP3:
372+
case VgprP4:
373+
case VgprP5:
173374
case VgprV4S32:
375+
case VgprB32:
376+
case VgprB64:
377+
case VgprB96:
378+
case VgprB128:
379+
case VgprB256:
380+
case VgprB512:
174381
return VgprRB;
175382
default:
176383
return nullptr;
@@ -195,16 +402,40 @@ void RegBankLegalizeHelper::applyMappingDst(
195402
case Sgpr16:
196403
case Sgpr32:
197404
case Sgpr64:
405+
case SgprP1:
406+
case SgprP3:
407+
case SgprP4:
408+
case SgprP5:
198409
case SgprV4S32:
199410
case Vgpr32:
200411
case Vgpr64:
201412
case VgprP1:
413+
case VgprP3:
414+
case VgprP4:
415+
case VgprP5:
202416
case VgprV4S32: {
203417
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
204418
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
205419
break;
206420
}
207-
// uniform in vcc/vgpr: scalars and vectors
421+
// sgpr and vgpr B-types
422+
case SgprB32:
423+
case SgprB64:
424+
case SgprB96:
425+
case SgprB128:
426+
case SgprB256:
427+
case SgprB512:
428+
case VgprB32:
429+
case VgprB64:
430+
case VgprB96:
431+
case VgprB128:
432+
case VgprB256:
433+
case VgprB512: {
434+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
435+
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
436+
break;
437+
}
438+
// uniform in vcc/vgpr: scalars, vectors and B-types
208439
case UniInVcc: {
209440
assert(Ty == S1);
210441
assert(RB == SgprRB);
@@ -223,6 +454,19 @@ void RegBankLegalizeHelper::applyMappingDst(
223454
buildReadAnyLane(B, Reg, NewVgprDst, RBI);
224455
break;
225456
}
457+
case UniInVgprB32:
458+
case UniInVgprB64:
459+
case UniInVgprB96:
460+
case UniInVgprB128:
461+
case UniInVgprB256:
462+
case UniInVgprB512: {
463+
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
464+
assert(RB == SgprRB);
465+
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
466+
Op.setReg(NewVgprDst);
467+
AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
468+
break;
469+
}
226470
// sgpr trunc
227471
case Sgpr32Trunc: {
228472
assert(Ty.getSizeInBits() < 32);
@@ -270,15 +514,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
270514
case Sgpr16:
271515
case Sgpr32:
272516
case Sgpr64:
517+
case SgprP1:
518+
case SgprP3:
519+
case SgprP4:
520+
case SgprP5:
273521
case SgprV4S32: {
274522
assert(Ty == getTyFromID(MethodIDs[i]));
275523
assert(RB == getRegBankFromID(MethodIDs[i]));
276524
break;
277525
}
526+
// sgpr B-types
527+
case SgprB32:
528+
case SgprB64:
529+
case SgprB96:
530+
case SgprB128:
531+
case SgprB256:
532+
case SgprB512: {
533+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
534+
assert(RB == getRegBankFromID(MethodIDs[i]));
535+
break;
536+
}
278537
// vgpr scalars, pointers and vectors
279538
case Vgpr32:
280539
case Vgpr64:
281540
case VgprP1:
541+
case VgprP3:
542+
case VgprP4:
543+
case VgprP5:
282544
case VgprV4S32: {
283545
assert(Ty == getTyFromID(MethodIDs[i]));
284546
if (RB != VgprRB) {
@@ -287,6 +549,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
287549
}
288550
break;
289551
}
552+
// vgpr B-types
553+
case VgprB32:
554+
case VgprB64:
555+
case VgprB96:
556+
case VgprB128:
557+
case VgprB256:
558+
case VgprB512: {
559+
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
560+
if (RB != VgprRB) {
561+
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
562+
Op.setReg(CopyToVgpr.getReg(0));
563+
}
564+
break;
565+
}
290566
// sgpr and vgpr scalars with extend
291567
case Sgpr32AExt: {
292568
// Note: this ext allows S1, and it is meant to be combined away.
@@ -359,7 +635,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
359635
// We accept all types that can fit in some register class.
360636
// Uniform G_PHIs have all sgpr registers.
361637
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
362-
if (Ty == LLT::scalar(32)) {
638+
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
363639
return;
364640
}
365641

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class RegBankLegalizeHelper {
9090
SmallSet<Register, 4> &SgprOperandRegs);
9191

9292
LLT getTyFromID(RegBankLLTMapingApplyID ID);
93+
LLT getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty);
9394

9495
const RegisterBank *getRegBankFromID(RegBankLLTMapingApplyID ID);
9596

@@ -102,6 +103,10 @@ class RegBankLegalizeHelper {
102103
const SmallVectorImpl<RegBankLLTMapingApplyID> &MethodIDs,
103104
SmallSet<Register, 4> &SgprWaterfallOperandRegs);
104105

106+
void splitLoad(MachineInstr &MI, ArrayRef<LLT> LLTBreakdown,
107+
LLT MergeTy = LLT());
108+
void widenLoad(MachineInstr &MI, LLT WideTy, LLT MergeTy = LLT());
109+
105110
void lower(MachineInstr &MI, const RegBankLLTMapping &Mapping,
106111
SmallSet<Register, 4> &SgprWaterfallOperandRegs);
107112
};

0 commit comments

Comments
 (0)