@@ -572,9 +572,10 @@ MCRegister SIRegisterInfo::reservedPrivateSegmentBufferReg(
572
572
std::pair<unsigned , unsigned >
573
573
SIRegisterInfo::getMaxNumVectorRegs (const MachineFunction &MF) const {
574
574
const SIMachineFunctionInfo *MFI = MF.getInfo <SIMachineFunctionInfo>();
575
- unsigned MaxNumVGPRs = ST.getMaxNumVGPRs (MF);
576
- unsigned MaxNumAGPRs = MaxNumVGPRs;
577
- unsigned TotalNumVGPRs = AMDGPU::VGPR_32RegClass.getNumRegs ();
575
+ const unsigned MaxVectorRegs = ST.getMaxNumVGPRs (MF);
576
+
577
+ unsigned MaxNumVGPRs = MaxVectorRegs;
578
+ unsigned MaxNumAGPRs = 0 ;
578
579
579
580
// On GFX90A, the number of VGPRs and AGPRs need not be equal. Theoretically,
580
581
// a wave may have up to 512 total vector registers combining together both
@@ -585,16 +586,49 @@ SIRegisterInfo::getMaxNumVectorRegs(const MachineFunction &MF) const {
585
586
// TODO: it shall be possible to estimate maximum AGPR/VGPR pressure and split
586
587
// register file accordingly.
587
588
if (ST.hasGFX90AInsts ()) {
588
- if (MFI->mayNeedAGPRs ()) {
589
- MaxNumVGPRs /= 2 ;
590
- MaxNumAGPRs = MaxNumVGPRs;
589
+ unsigned MinNumAGPRs = 0 ;
590
+ const unsigned TotalNumAGPRs = AMDGPU::AGPR_32RegClass.getNumRegs ();
591
+ const unsigned TotalNumVGPRs = AMDGPU::VGPR_32RegClass.getNumRegs ();
592
+
593
+ const std::pair<unsigned , unsigned > DefaultNumAGPR = {~0u , ~0u };
594
+
595
+ // TODO: Replace amdgpu-no-agpr with amdgpu-agpr-alloc=0
596
+ // TODO: Move this logic into subtarget on IR function
597
+ //
598
+ // TODO: The lower bound should probably force the number of required
599
+ // registers up, overriding amdgpu-waves-per-eu.
600
+ std::tie (MinNumAGPRs, MaxNumAGPRs) = AMDGPU::getIntegerPairAttribute (
601
+ MF.getFunction (), " amdgpu-agpr-alloc" , DefaultNumAGPR,
602
+ /* OnlyFirstRequired=*/ true );
603
+
604
+ if (MinNumAGPRs == DefaultNumAGPR.first ) {
605
+ // Default to splitting half the registers if AGPRs are required.
606
+
607
+ if (MFI->mayNeedAGPRs ())
608
+ MinNumAGPRs = MaxNumAGPRs = MaxVectorRegs / 2 ;
609
+ else
610
+ MinNumAGPRs = 0 ;
591
611
} else {
592
- if (MaxNumVGPRs > TotalNumVGPRs) {
593
- MaxNumAGPRs = MaxNumVGPRs - TotalNumVGPRs;
594
- MaxNumVGPRs = TotalNumVGPRs;
595
- } else
596
- MaxNumAGPRs = 0 ;
612
+ // Align to accum_offset's allocation granularity.
613
+ MinNumAGPRs = alignTo (MinNumAGPRs, 4 );
614
+
615
+ MinNumAGPRs = std::min (MinNumAGPRs, TotalNumAGPRs);
597
616
}
617
+
618
+ // Clamp values to be inbounds of our limits, and ensure min <= max.
619
+
620
+ MaxNumAGPRs = std::min (std::max (MinNumAGPRs, MaxNumAGPRs), MaxVectorRegs);
621
+ MinNumAGPRs = std::min (std::min (MinNumAGPRs, TotalNumAGPRs), MaxNumAGPRs);
622
+
623
+ MaxNumVGPRs = std::min (MaxVectorRegs - MinNumAGPRs, TotalNumVGPRs);
624
+ MaxNumAGPRs = std::min (MaxVectorRegs - MaxNumVGPRs, MaxNumAGPRs);
625
+
626
+ assert (MaxNumVGPRs + MaxNumAGPRs <= MaxVectorRegs &&
627
+ MaxNumAGPRs <= TotalNumAGPRs && MaxNumVGPRs <= TotalNumVGPRs &&
628
+ " invalid register counts" );
629
+ } else if (ST.hasMAIInsts ()) {
630
+ // On gfx908 the number of AGPRs always equals the number of VGPRs.
631
+ MaxNumAGPRs = MaxNumVGPRs = MaxVectorRegs;
598
632
}
599
633
600
634
return std::pair (MaxNumVGPRs, MaxNumAGPRs);
0 commit comments