@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
37
37
return true ;
38
38
}
39
39
40
+ void RegBankLegalizeHelper::splitLoad (MachineInstr &MI,
41
+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
42
+ MachineFunction &MF = B.getMF ();
43
+ assert (MI.getNumMemOperands () == 1 );
44
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
45
+ Register Dst = MI.getOperand (0 ).getReg ();
46
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
47
+ Register BasePtrReg = MI.getOperand (1 ).getReg ();
48
+ LLT PtrTy = MRI.getType (BasePtrReg);
49
+ const RegisterBank *PtrRB = MRI.getRegBankOrNull (BasePtrReg);
50
+ LLT OffsetTy = LLT::scalar (PtrTy.getSizeInBits ());
51
+ SmallVector<Register, 4 > LoadPartRegs;
52
+
53
+ unsigned ByteOffset = 0 ;
54
+ for (LLT PartTy : LLTBreakdown) {
55
+ Register BasePtrPlusOffsetReg;
56
+ if (ByteOffset == 0 ) {
57
+ BasePtrPlusOffsetReg = BasePtrReg;
58
+ } else {
59
+ BasePtrPlusOffsetReg = MRI.createVirtualRegister ({PtrRB, PtrTy});
60
+ Register OffsetReg = MRI.createVirtualRegister ({PtrRB, OffsetTy});
61
+ B.buildConstant (OffsetReg, ByteOffset);
62
+ B.buildPtrAdd (BasePtrPlusOffsetReg, BasePtrReg, OffsetReg);
63
+ }
64
+ MachineMemOperand *BasePtrPlusOffsetMMO =
65
+ MF.getMachineMemOperand (&BaseMMO, ByteOffset, PartTy);
66
+ Register PartLoad = MRI.createVirtualRegister ({DstRB, PartTy});
67
+ B.buildLoad (PartLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
68
+ LoadPartRegs.push_back (PartLoad);
69
+ ByteOffset += PartTy.getSizeInBytes ();
70
+ }
71
+
72
+ if (!MergeTy.isValid ()) {
73
+ // Loads are of same size, concat or merge them together.
74
+ B.buildMergeLikeInstr (Dst, LoadPartRegs);
75
+ } else {
76
+ // Load(s) are not all of same size, need to unmerge them to smaller pieces
77
+ // of MergeTy type, then merge them all together in Dst.
78
+ SmallVector<Register, 4 > MergeTyParts;
79
+ for (Register Reg : LoadPartRegs) {
80
+ if (MRI.getType (Reg) == MergeTy) {
81
+ MergeTyParts.push_back (Reg);
82
+ } else {
83
+ auto Unmerge = B.buildUnmerge (MergeTy, Reg);
84
+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i) {
85
+ Register UnmergeReg = Unmerge->getOperand (i).getReg ();
86
+ MRI.setRegBank (UnmergeReg, *DstRB);
87
+ MergeTyParts.push_back (UnmergeReg);
88
+ }
89
+ }
90
+ }
91
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
92
+ }
93
+ MI.eraseFromParent ();
94
+ }
95
+
96
+ void RegBankLegalizeHelper::widenLoad (MachineInstr &MI, LLT WideTy,
97
+ LLT MergeTy) {
98
+ MachineFunction &MF = B.getMF ();
99
+ assert (MI.getNumMemOperands () == 1 );
100
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
101
+ Register Dst = MI.getOperand (0 ).getReg ();
102
+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
103
+ Register BasePtrReg = MI.getOperand (1 ).getReg ();
104
+
105
+ Register BasePtrPlusOffsetReg;
106
+ BasePtrPlusOffsetReg = BasePtrReg;
107
+
108
+ MachineMemOperand *BasePtrPlusOffsetMMO =
109
+ MF.getMachineMemOperand (&BaseMMO, 0 , WideTy);
110
+ Register WideLoad = MRI.createVirtualRegister ({DstRB, WideTy});
111
+ B.buildLoad (WideLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
112
+
113
+ if (WideTy.isScalar ()) {
114
+ B.buildTrunc (Dst, WideLoad);
115
+ } else {
116
+ SmallVector<Register, 4 > MergeTyParts;
117
+ unsigned NumEltsMerge =
118
+ MRI.getType (Dst).getSizeInBits () / MergeTy.getSizeInBits ();
119
+ auto Unmerge = B.buildUnmerge (MergeTy, WideLoad);
120
+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i) {
121
+ Register UnmergeReg = Unmerge->getOperand (i).getReg ();
122
+ MRI.setRegBank (UnmergeReg, *DstRB);
123
+ if (i < NumEltsMerge)
124
+ MergeTyParts.push_back (UnmergeReg);
125
+ }
126
+ B.buildMergeLikeInstr (Dst, MergeTyParts);
127
+ }
128
+ MI.eraseFromParent ();
129
+ }
130
+
40
131
void RegBankLegalizeHelper::lower (MachineInstr &MI,
41
132
const RegBankLLTMapping &Mapping,
42
133
SmallSet<Register, 4 > &WaterfallSGPRs) {
@@ -119,6 +210,50 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
119
210
MI.eraseFromParent ();
120
211
break ;
121
212
}
213
+ case SplitLoad: {
214
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
215
+ unsigned Size = DstTy.getSizeInBits ();
216
+ // Even split to 128-bit loads
217
+ if (Size > 128 ) {
218
+ LLT B128;
219
+ if (DstTy.isVector ()) {
220
+ LLT EltTy = DstTy.getElementType ();
221
+ B128 = LLT::fixed_vector (128 / EltTy.getSizeInBits (), EltTy);
222
+ } else {
223
+ B128 = LLT::scalar (128 );
224
+ }
225
+ if (Size / 128 == 2 )
226
+ splitLoad (MI, {B128, B128});
227
+ if (Size / 128 == 4 )
228
+ splitLoad (MI, {B128, B128, B128, B128});
229
+ }
230
+ // 64 and 32 bit load
231
+ else if (DstTy == S96)
232
+ splitLoad (MI, {S64, S32}, S32);
233
+ else if (DstTy == V3S32)
234
+ splitLoad (MI, {V2S32, S32}, S32);
235
+ else if (DstTy == V6S16)
236
+ splitLoad (MI, {V4S16, V2S16}, V2S16);
237
+ else {
238
+ MI.dump ();
239
+ llvm_unreachable (" SplitLoad type not supported" );
240
+ }
241
+ break ;
242
+ }
243
+ case WidenLoad: {
244
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
245
+ if (DstTy == S96)
246
+ widenLoad (MI, S128);
247
+ else if (DstTy == V3S32)
248
+ widenLoad (MI, V4S32, S32);
249
+ else if (DstTy == V6S16)
250
+ widenLoad (MI, V8S16, V2S16);
251
+ else {
252
+ MI.dump ();
253
+ llvm_unreachable (" WidenLoad type not supported" );
254
+ }
255
+ break ;
256
+ }
122
257
}
123
258
124
259
// TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +277,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
142
277
case Sgpr64:
143
278
case Vgpr64:
144
279
return LLT::scalar (64 );
145
-
280
+ case SgprP1:
281
+ case VgprP1:
282
+ return LLT::pointer (1 , 64 );
283
+ case SgprP3:
284
+ case VgprP3:
285
+ return LLT::pointer (3 , 32 );
286
+ case SgprP4:
287
+ case VgprP4:
288
+ return LLT::pointer (4 , 64 );
289
+ case SgprP5:
290
+ case VgprP5:
291
+ return LLT::pointer (5 , 32 );
146
292
case SgprV4S32:
147
293
case VgprV4S32:
148
294
case UniInVgprV4S32:
149
295
return LLT::fixed_vector (4 , 32 );
150
- case VgprP1:
151
- return LLT::pointer (1 , 64 );
296
+ default :
297
+ return LLT ();
298
+ }
299
+ }
300
+
301
+ LLT RegBankLegalizeHelper::getBTyFromID (RegBankLLTMapingApplyID ID, LLT Ty) {
302
+ switch (ID) {
303
+ case SgprB32:
304
+ case VgprB32:
305
+ case UniInVgprB32:
306
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::fixed_vector (2 , 16 ) ||
307
+ Ty == LLT::pointer (3 , 32 ) || Ty == LLT::pointer (5 , 32 ) ||
308
+ Ty == LLT::pointer (6 , 32 ))
309
+ return Ty;
310
+ return LLT ();
311
+ case SgprB64:
312
+ case VgprB64:
313
+ case UniInVgprB64:
314
+ if (Ty == LLT::scalar (64 ) || Ty == LLT::fixed_vector (2 , 32 ) ||
315
+ Ty == LLT::fixed_vector (4 , 16 ) || Ty == LLT::pointer (0 , 64 ) ||
316
+ Ty == LLT::pointer (1 , 64 ) || Ty == LLT::pointer (4 , 64 ))
317
+ return Ty;
318
+ return LLT ();
319
+ case SgprB96:
320
+ case VgprB96:
321
+ case UniInVgprB96:
322
+ if (Ty == LLT::scalar (96 ) || Ty == LLT::fixed_vector (3 , 32 ) ||
323
+ Ty == LLT::fixed_vector (6 , 16 ))
324
+ return Ty;
325
+ return LLT ();
326
+ case SgprB128:
327
+ case VgprB128:
328
+ case UniInVgprB128:
329
+ if (Ty == LLT::scalar (128 ) || Ty == LLT::fixed_vector (4 , 32 ) ||
330
+ Ty == LLT::fixed_vector (2 , 64 ))
331
+ return Ty;
332
+ return LLT ();
333
+ case SgprB256:
334
+ case VgprB256:
335
+ case UniInVgprB256:
336
+ if (Ty == LLT::scalar (256 ) || Ty == LLT::fixed_vector (8 , 32 ) ||
337
+ Ty == LLT::fixed_vector (4 , 64 ) || Ty == LLT::fixed_vector (16 , 16 ))
338
+ return Ty;
339
+ return LLT ();
340
+ case SgprB512:
341
+ case VgprB512:
342
+ case UniInVgprB512:
343
+ if (Ty == LLT::scalar (512 ) || Ty == LLT::fixed_vector (16 , 32 ) ||
344
+ Ty == LLT::fixed_vector (8 , 64 ))
345
+ return Ty;
346
+ return LLT ();
152
347
default :
153
348
return LLT ();
154
349
}
@@ -163,10 +358,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
163
358
case Sgpr16:
164
359
case Sgpr32:
165
360
case Sgpr64:
361
+ case SgprP1:
362
+ case SgprP3:
363
+ case SgprP4:
364
+ case SgprP5:
166
365
case SgprV4S32:
366
+ case SgprB32:
367
+ case SgprB64:
368
+ case SgprB96:
369
+ case SgprB128:
370
+ case SgprB256:
371
+ case SgprB512:
167
372
case UniInVcc:
168
373
case UniInVgprS32:
169
374
case UniInVgprV4S32:
375
+ case UniInVgprB32:
376
+ case UniInVgprB64:
377
+ case UniInVgprB96:
378
+ case UniInVgprB128:
379
+ case UniInVgprB256:
380
+ case UniInVgprB512:
170
381
case Sgpr32Trunc:
171
382
case Sgpr32AExt:
172
383
case Sgpr32AExtBoolInReg:
@@ -176,7 +387,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
176
387
case Vgpr32:
177
388
case Vgpr64:
178
389
case VgprP1:
390
+ case VgprP3:
391
+ case VgprP4:
392
+ case VgprP5:
179
393
case VgprV4S32:
394
+ case VgprB32:
395
+ case VgprB64:
396
+ case VgprB96:
397
+ case VgprB128:
398
+ case VgprB256:
399
+ case VgprB512:
180
400
return VgprRB;
181
401
182
402
default :
@@ -202,17 +422,42 @@ void RegBankLegalizeHelper::applyMappingDst(
202
422
case Sgpr16:
203
423
case Sgpr32:
204
424
case Sgpr64:
425
+ case SgprP1:
426
+ case SgprP3:
427
+ case SgprP4:
428
+ case SgprP5:
205
429
case SgprV4S32:
206
430
case Vgpr32:
207
431
case Vgpr64:
208
432
case VgprP1:
433
+ case VgprP3:
434
+ case VgprP4:
435
+ case VgprP5:
209
436
case VgprV4S32: {
210
437
assert (Ty == getTyFromID (MethodIDs[OpIdx]));
211
438
assert (RB == getRBFromID (MethodIDs[OpIdx]));
212
439
break ;
213
440
}
214
441
215
- // uniform in vcc/vgpr: scalars and vectors
442
+ // sgpr and vgpr B-types
443
+ case SgprB32:
444
+ case SgprB64:
445
+ case SgprB96:
446
+ case SgprB128:
447
+ case SgprB256:
448
+ case SgprB512:
449
+ case VgprB32:
450
+ case VgprB64:
451
+ case VgprB96:
452
+ case VgprB128:
453
+ case VgprB256:
454
+ case VgprB512: {
455
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
456
+ assert (RB == getRBFromID (MethodIDs[OpIdx]));
457
+ break ;
458
+ }
459
+
460
+ // uniform in vcc/vgpr: scalars, vectors and B-types
216
461
case UniInVcc: {
217
462
assert (Ty == S1);
218
463
assert (RB == SgprRB);
@@ -229,6 +474,17 @@ void RegBankLegalizeHelper::applyMappingDst(
229
474
AMDGPU::buildReadAnyLaneDst (B, MI, RBI);
230
475
break ;
231
476
}
477
+ case UniInVgprB32:
478
+ case UniInVgprB64:
479
+ case UniInVgprB96:
480
+ case UniInVgprB128:
481
+ case UniInVgprB256:
482
+ case UniInVgprB512: {
483
+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
484
+ assert (RB == SgprRB);
485
+ AMDGPU::buildReadAnyLaneDst (B, MI, RBI);
486
+ break ;
487
+ }
232
488
233
489
// sgpr trunc
234
490
case Sgpr32Trunc: {
@@ -279,16 +535,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
279
535
case Sgpr16:
280
536
case Sgpr32:
281
537
case Sgpr64:
538
+ case SgprP1:
539
+ case SgprP3:
540
+ case SgprP4:
541
+ case SgprP5:
282
542
case SgprV4S32: {
283
543
assert (Ty == getTyFromID (MethodIDs[i]));
284
544
assert (RB == getRBFromID (MethodIDs[i]));
285
545
break ;
286
546
}
547
+ // sgpr B-types
548
+ case SgprB32:
549
+ case SgprB64:
550
+ case SgprB96:
551
+ case SgprB128:
552
+ case SgprB256:
553
+ case SgprB512: {
554
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
555
+ assert (RB == getRBFromID (MethodIDs[i]));
556
+ break ;
557
+ }
287
558
288
559
// vgpr scalars, pointers and vectors
289
560
case Vgpr32:
290
561
case Vgpr64:
291
562
case VgprP1:
563
+ case VgprP3:
564
+ case VgprP4:
565
+ case VgprP5:
292
566
case VgprV4S32: {
293
567
assert (Ty == getTyFromID (MethodIDs[i]));
294
568
if (RB != VgprRB) {
@@ -298,6 +572,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
298
572
}
299
573
break ;
300
574
}
575
+ // vgpr B-types
576
+ case VgprB32:
577
+ case VgprB64:
578
+ case VgprB96:
579
+ case VgprB128:
580
+ case VgprB256:
581
+ case VgprB512: {
582
+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
583
+ if (RB != VgprRB) {
584
+ auto CopyToVgpr =
585
+ B.buildCopy (createVgpr (getBTyFromID (MethodIDs[i], Ty)), Reg);
586
+ Op.setReg (CopyToVgpr.getReg (0 ));
587
+ }
588
+ break ;
589
+ }
301
590
302
591
// sgpr and vgpr scalars with extend
303
592
case Sgpr32AExt: {
@@ -372,7 +661,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
372
661
// We accept all types that can fit in some register class.
373
662
// Uniform G_PHIs have all sgpr registers.
374
663
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
375
- if (Ty == LLT::scalar (32 )) {
664
+ if (Ty == LLT::scalar (32 ) || Ty == LLT::pointer ( 4 , 64 ) ) {
376
665
return ;
377
666
}
378
667
0 commit comments