@@ -38,6 +38,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
   lower(MI, Mapping, WaterfallSgprs);
 }

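+// Split a wide load into the smaller loads given by LLTBreakdown, loading each
+// part at an increasing byte offset, then merge the parts back into Dst
+// (unmerging unequal parts to MergeTy first).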
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register Base = MI.getOperand(1).getReg();
+  LLT PtrTy = MRI.getType(Base);
+  const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
+  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+  SmallVector<Register, 4> LoadPartRegs;
+
+  unsigned ByteOffset = 0;
+  for (LLT PartTy : LLTBreakdown) {
+    Register BasePlusOffset;
+    if (ByteOffset == 0) {
+      BasePlusOffset = Base;
+    } else {
+      auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
+      BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
+    }
+    auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+    auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
+    LoadPartRegs.push_back(LoadPart.getReg(0));
+    ByteOffset += PartTy.getSizeInBytes();
+  }
+
+  if (!MergeTy.isValid()) {
+    // Loads are of same size, concat or merge them together.
+    B.buildMergeLikeInstr(Dst, LoadPartRegs);
+  } else {
+    // Loads are not all of same size, need to unmerge them to smaller pieces
+    // of MergeTy type, then merge pieces to Dst.
+    SmallVector<Register, 4> MergeTyParts;
+    for (Register Reg : LoadPartRegs) {
+      if (MRI.getType(Reg) == MergeTy) {
+        MergeTyParts.push_back(Reg);
+      } else {
+        auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
+        for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
+          MergeTyParts.push_back(Unmerge.getReg(i));
+      }
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
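+// Widen the load to WideTy, then cut the wide result back down to the original
+// destination type: trunc for scalars, unmerge to MergeTy and re-merge for
+// vectors.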
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+                                      LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register Base = MI.getOperand(1).getReg();
+
+  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+  auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
+
+  if (WideTy.isScalar()) {
+    B.buildTrunc(Dst, WideLoad);
+  } else {
+    SmallVector<Register, 4> MergeTyParts;
+    auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
+
+    LLT DstTy = MRI.getType(Dst);
+    unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
+    for (unsigned i = 0; i < NumElts; ++i) {
+      MergeTyParts.push_back(Unmerge.getReg(i));
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
 void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                   const RegBankLLTMapping &Mapping,
                                   SmallSet<Register, 4> &WaterfallSgprs) {
@@ -116,6 +193,50 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
     MI.eraseFromParent();
     break;
   }
+  case SplitLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    unsigned Size = DstTy.getSizeInBits();
+    // Even split to 128-bit loads
+    if (Size > 128) {
+      LLT B128;
+      if (DstTy.isVector()) {
+        LLT EltTy = DstTy.getElementType();
+        B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
+      } else {
+        B128 = LLT::scalar(128);
+      }
+      if (Size / 128 == 2)
+        splitLoad(MI, {B128, B128});
+      if (Size / 128 == 4)
+        splitLoad(MI, {B128, B128, B128, B128});
+    }
+    // 64 and 32 bit load
+    else if (DstTy == S96)
+      splitLoad(MI, {S64, S32}, S32);
+    else if (DstTy == V3S32)
+      splitLoad(MI, {V2S32, S32}, S32);
+    else if (DstTy == V6S16)
+      splitLoad(MI, {V4S16, V2S16}, V2S16);
+    else {
+      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
+      llvm_unreachable("SplitLoad type not supported for MI");
+    }
+    break;
+  }
+  case WidenLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    if (DstTy == S96)
+      widenLoad(MI, S128);
+    else if (DstTy == V3S32)
+      widenLoad(MI, V4S32, S32);
+    else if (DstTy == V6S16)
+      widenLoad(MI, V8S16, V2S16);
+    else {
+      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
+      llvm_unreachable("WidenLoad type not supported for MI");
+    }
+    break;
+  }
   }

   // TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -139,12 +260,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr64:
   case Vgpr64:
     return LLT::scalar(64);
+  case SgprP1:
+  case VgprP1:
+    return LLT::pointer(1, 64);
+  case SgprP3:
+  case VgprP3:
+    return LLT::pointer(3, 32);
+  case SgprP4:
+  case VgprP4:
+    return LLT::pointer(4, 64);
+  case SgprP5:
+  case VgprP5:
+    return LLT::pointer(5, 32);
   case SgprV4S32:
   case VgprV4S32:
   case UniInVgprV4S32:
     return LLT::fixed_vector(4, 32);
-  case VgprP1:
-    return LLT::pointer(1, 64);
+  default:
+    return LLT();
+  }
+}
+
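+// B-type IDs cover several layouts of the same bit width. Return Ty if it is
+// one of the layouts accepted for ID, otherwise return an invalid LLT.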
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+  switch (ID) {
+  case SgprB32:
+  case VgprB32:
+  case UniInVgprB32:
+    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+        Ty == LLT::pointer(6, 32))
+      return Ty;
+    return LLT();
+  case SgprB64:
+  case VgprB64:
+  case UniInVgprB64:
+    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+      return Ty;
+    return LLT();
+  case SgprB96:
+  case VgprB96:
+  case UniInVgprB96:
+    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+        Ty == LLT::fixed_vector(6, 16))
+      return Ty;
+    return LLT();
+  case SgprB128:
+  case VgprB128:
+  case UniInVgprB128:
+    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+        Ty == LLT::fixed_vector(2, 64))
+      return Ty;
+    return LLT();
+  case SgprB256:
+  case VgprB256:
+  case UniInVgprB256:
+    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+      return Ty;
+    return LLT();
+  case SgprB512:
+  case VgprB512:
+  case UniInVgprB512:
+    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+        Ty == LLT::fixed_vector(8, 64))
+      return Ty;
+    return LLT();
   default:
     return LLT();
   }
@@ -158,10 +340,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
   case UniInVcc:
   case UniInVgprS32:
   case UniInVgprV4S32:
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512:
   case Sgpr32Trunc:
   case Sgpr32AExt:
   case Sgpr32AExtBoolInReg:
@@ -170,7 +368,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMapingApplyID ID) {
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512:
     return VgprRB;
   default:
     return nullptr;
@@ -195,16 +402,40 @@ void RegBankLegalizeHelper::applyMappingDst(
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32: {
     assert(Ty == getTyFromID(MethodIDs[OpIdx]));
     assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
     break;
   }
-  // uniform in vcc/vgpr: scalars and vectors
+  // sgpr and vgpr B-types
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512: {
+    assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+    assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
+    break;
+  }
+  // uniform in vcc/vgpr: scalars, vectors and B-types
   case UniInVcc: {
     assert(Ty == S1);
     assert(RB == SgprRB);
@@ -223,6 +454,19 @@ void RegBankLegalizeHelper::applyMappingDst(
     buildReadAnyLane(B, Reg, NewVgprDst, RBI);
     break;
   }
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512: {
+    assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+    assert(RB == SgprRB);
+    Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
+    Op.setReg(NewVgprDst);
+    AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
+    break;
+  }
   // sgpr trunc
   case Sgpr32Trunc: {
     assert(Ty.getSizeInBits() < 32);
@@ -270,15 +514,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32: {
     assert(Ty == getTyFromID(MethodIDs[i]));
     assert(RB == getRegBankFromID(MethodIDs[i]));
     break;
   }
+  // sgpr B-types
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512: {
+    assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+    assert(RB == getRegBankFromID(MethodIDs[i]));
+    break;
+  }
   // vgpr scalars, pointers and vectors
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32: {
     assert(Ty == getTyFromID(MethodIDs[i]));
     if (RB != VgprRB) {
@@ -287,6 +549,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
     }
     break;
   }
+  // vgpr B-types
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512: {
+    assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+    if (RB != VgprRB) {
+      auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
+      Op.setReg(CopyToVgpr.getReg(0));
+    }
+    break;
+  }
   // sgpr and vgpr scalars with extend
   case Sgpr32AExt: {
     // Note: this ext allows S1, and it is meant to be combined away.
@@ -359,7 +635,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
   // We accept all types that can fit in some register class.
   // Uniform G_PHIs have all sgpr registers.
   // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
-  if (Ty == LLT::scalar(32)) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
     return;
   }