@@ -36,6 +36,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
   lower(MI, Mapping, WaterfallSgprs);
 }
 
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register Base = MI.getOperand(1).getReg();
+  LLT PtrTy = MRI.getType(Base);
+  const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
+  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+  SmallVector<Register, 4> LoadPartRegs;
+
+  unsigned ByteOffset = 0;
+  for (LLT PartTy : LLTBreakdown) {
+    Register BasePlusOffset;
+    if (ByteOffset == 0) {
+      BasePlusOffset = Base;
+    } else {
+      auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
+      BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
+    }
+    auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+    auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
+    LoadPartRegs.push_back(LoadPart.getReg(0));
+    ByteOffset += PartTy.getSizeInBytes();
+  }
+
+  if (!MergeTy.isValid()) {
+    // Loads are of same size, concat or merge them together.
+    B.buildMergeLikeInstr(Dst, LoadPartRegs);
+  } else {
+    // Loads are not all of same size, need to unmerge them to smaller pieces
+    // of MergeTy type, then merge pieces to Dst.
+    SmallVector<Register, 4> MergeTyParts;
+    for (Register Reg : LoadPartRegs) {
+      if (MRI.getType(Reg) == MergeTy) {
+        MergeTyParts.push_back(Reg);
+      } else {
+        auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
+        for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
+          MergeTyParts.push_back(Unmerge.getReg(i));
+      }
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+                                      LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register Base = MI.getOperand(1).getReg();
+
+  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+  auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
+
+  if (WideTy.isScalar()) {
+    B.buildTrunc(Dst, WideLoad);
+  } else {
+    SmallVector<Register, 4> MergeTyParts;
+    auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
+
+    LLT DstTy = MRI.getType(Dst);
+    unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
+    for (unsigned i = 0; i < NumElts; ++i) {
+      MergeTyParts.push_back(Unmerge.getReg(i));
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
 void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                   const RegBankLLTMapping &Mapping,
                                   SmallSet<Register, 4> &WaterfallSgprs) {
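As a rough illustration of the offset bookkeeping in splitLoad above (this is standalone sketch code, not LLVM code, and the {64, 32} breakdown is only an assumed example), each part is loaded at the running ByteOffset, which then advances by that part's size in bytes:

#include <cstdio>
#include <vector>

int main() {
  // Assumed example: a 96-bit load split as {S64, S32}; sizes are in bits.
  std::vector<unsigned> PartSizesInBits = {64, 32};
  unsigned ByteOffset = 0;
  for (unsigned Bits : PartSizesInBits) {
    std::printf("part load: %u bits at byte offset %u\n", Bits, ByteOffset);
    ByteOffset += Bits / 8; // splitLoad advances by PartTy.getSizeInBytes()
  }
  return 0;
}

With MergeTy set to S32, splitLoad then unmerges the 64-bit part into two 32-bit pieces so that all three pieces can be merged back into the 96-bit destination.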
@@ -114,6 +191,50 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
     MI.eraseFromParent();
     break;
   }
+  case SplitLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    unsigned Size = DstTy.getSizeInBits();
+    // Even split to 128-bit loads
+    if (Size > 128) {
+      LLT B128;
+      if (DstTy.isVector()) {
+        LLT EltTy = DstTy.getElementType();
+        B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
+      } else {
+        B128 = LLT::scalar(128);
+      }
+      if (Size / 128 == 2)
+        splitLoad(MI, {B128, B128});
+      if (Size / 128 == 4)
+        splitLoad(MI, {B128, B128, B128, B128});
+    }
+    // 64 and 32 bit load
+    else if (DstTy == S96)
+      splitLoad(MI, {S64, S32}, S32);
+    else if (DstTy == V3S32)
+      splitLoad(MI, {V2S32, S32}, S32);
+    else if (DstTy == V6S16)
+      splitLoad(MI, {V4S16, V2S16}, V2S16);
+    else {
+      MI.dump();
+      llvm_unreachable("SplitLoad type not supported");
+    }
+    break;
+  }
+  case WidenLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    if (DstTy == S96)
+      widenLoad(MI, S128);
+    else if (DstTy == V3S32)
+      widenLoad(MI, V4S32, S32);
+    else if (DstTy == V6S16)
+      widenLoad(MI, V8S16, V2S16);
+    else {
+      MI.dump();
+      llvm_unreachable("WidenLoad type not supported");
+    }
+    break;
+  }
   }
 
   // TODO: executeInWaterfallLoop(... WaterfallSgprs)
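For reference, here is a minimal sketch of the size-to-breakdown choice made by the SplitLoad case above; the helper name chooseBreakdown is made up for illustration and types are reduced to plain bit widths:

#include <vector>

// Hypothetical helper mirroring the selection logic above.
std::vector<unsigned> chooseBreakdown(unsigned DstSizeInBits) {
  if (DstSizeInBits > 128) {
    // Even split into 128-bit parts; the real code only handles the
    // two-part (256-bit) and four-part (512-bit) cases.
    return std::vector<unsigned>(DstSizeInBits / 128, 128u);
  }
  if (DstSizeInBits == 96)
    return {64, 32}; // matches the S96, V3S32 and V6S16 cases
  return {DstSizeInBits}; // anything else is left alone in this sketch
}

The WidenLoad case goes the other way: it rounds a 96-bit destination up to a 128-bit memory access and then truncates (scalar) or unmerges (vector) the wide result back to the original type.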
@@ -137,13 +258,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr64:
   case Vgpr64:
     return LLT::scalar(64);
-
+  case SgprP1:
+  case VgprP1:
+    return LLT::pointer(1, 64);
+  case SgprP3:
+  case VgprP3:
+    return LLT::pointer(3, 32);
+  case SgprP4:
+  case VgprP4:
+    return LLT::pointer(4, 64);
+  case SgprP5:
+  case VgprP5:
+    return LLT::pointer(5, 32);
   case SgprV4S32:
   case VgprV4S32:
   case UniInVgprV4S32:
     return LLT::fixed_vector(4, 32);
-  case VgprP1:
-    return LLT::pointer(1, 64);
+  default:
+    return LLT();
+  }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+  switch (ID) {
+  case SgprB32:
+  case VgprB32:
+  case UniInVgprB32:
+    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+        Ty == LLT::pointer(6, 32))
+      return Ty;
+    return LLT();
+  case SgprB64:
+  case VgprB64:
+  case UniInVgprB64:
+    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+      return Ty;
+    return LLT();
+  case SgprB96:
+  case VgprB96:
+  case UniInVgprB96:
+    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+        Ty == LLT::fixed_vector(6, 16))
+      return Ty;
+    return LLT();
+  case SgprB128:
+  case VgprB128:
+  case UniInVgprB128:
+    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+        Ty == LLT::fixed_vector(2, 64))
+      return Ty;
+    return LLT();
+  case SgprB256:
+  case VgprB256:
+  case UniInVgprB256:
+    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+      return Ty;
+    return LLT();
+  case SgprB512:
+  case VgprB512:
+  case UniInVgprB512:
+    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+        Ty == LLT::fixed_vector(8, 64))
+      return Ty;
+    return LLT();
   default:
     return LLT();
   }
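The new B-type IDs act as size buckets: getBTyFromID returns the incoming type unchanged when it is one of the listed shapes of that bit width, and an invalid LLT otherwise. A simplified standalone sketch of the idea, assuming only total size is checked (the real function additionally whitelists the specific scalar, vector and pointer shapes shown above):

#include <cstdio>

// Stand-in for an LLT; only the total size matters in this sketch.
struct SimpleTy {
  unsigned SizeInBits;
};

// True when Ty would land in the B64 bucket (s64, <2 x s32>, <4 x s16>,
// p0, p1 or p4 in the code above).
bool fitsB64(SimpleTy Ty) { return Ty.SizeInBits == 64; }

int main() {
  SimpleTy S64{64}, V2S32{64}, P1{64}, S32{32};
  std::printf("%d %d %d %d\n", fitsB64(S64), fitsB64(V2S32), fitsB64(P1),
              fitsB64(S32)); // prints: 1 1 1 0
  return 0;
}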
@@ -158,10 +339,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
   case UniInVcc:
   case UniInVgprS32:
   case UniInVgprV4S32:
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512:
   case Sgpr32Trunc:
   case Sgpr32AExt:
   case Sgpr32AExtBoolInReg:
@@ -171,7 +368,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512:
     return VgprRB;
 
   default:
@@ -197,17 +403,42 @@ void RegBankLegalizeHelper::applyMappingDst(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32:
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[OpIdx]));
       assert(RB == getRBFromID(MethodIDs[OpIdx]));
       break;
     }
 
-    // uniform in vcc/vgpr: scalars and vectors
+    // sgpr and vgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512:
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == getRBFromID(MethodIDs[OpIdx]));
+      break;
+    }
+
+    // uniform in vcc/vgpr: scalars, vectors and B-types
     case UniInVcc: {
       assert(Ty == S1);
       assert(RB == SgprRB);
@@ -227,6 +458,20 @@ void RegBankLegalizeHelper::applyMappingDst(
       buildReadAnyLane(B, Reg, NewVgprDst, RBI);
       break;
     }
+    case UniInVgprB32:
+    case UniInVgprB64:
+    case UniInVgprB96:
+    case UniInVgprB128:
+    case UniInVgprB256:
+    case UniInVgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == SgprRB);
+
+      Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
+      Op.setReg(NewVgprDst);
+      AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
+      break;
+    }
 
     // sgpr trunc
     case Sgpr32Trunc: {
@@ -278,16 +523,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       assert(RB == getRBFromID(MethodIDs[i]));
       break;
     }
+    // sgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      assert(RB == getRBFromID(MethodIDs[i]));
+      break;
+    }
 
     // vgpr scalars, pointers and vectors
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       if (RB != VgprRB) {
@@ -296,6 +559,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
       }
       break;
     }
+    // vgpr B-types
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      if (RB != VgprRB) {
+        auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
+        Op.setReg(CopyToVgpr.getReg(0));
+      }
+      break;
+    }
 
     // sgpr and vgpr scalars with extend
     case Sgpr32AExt: {
@@ -368,7 +645,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
   // We accept all types that can fit in some register class.
   // Uniform G_PHIs have all sgpr registers.
   // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
-  if (Ty == LLT::scalar(32)) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
     return;
   }
 