@@ -295,6 +295,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
295
295
}
296
296
}
297
297
298
+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
299
+ static LaneBitmask getLanesWithProperty (
300
+ const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
301
+ bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
302
+ LaneBitmask SafeDefault,
303
+ function_ref<bool (const LiveRange &LR, SlotIndex Pos)> Property) {
304
+ if (RegUnit.isVirtual ()) {
305
+ const LiveInterval &LI = LIS.getInterval (RegUnit);
306
+ LaneBitmask Result;
307
+ if (TrackLaneMasks && LI.hasSubRanges ()) {
308
+ for (const LiveInterval::SubRange &SR : LI.subranges ()) {
309
+ if (Property (SR, Pos))
310
+ Result |= SR.LaneMask ;
311
+ }
312
+ } else if (Property (LI, Pos)) {
313
+ Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg (RegUnit)
314
+ : LaneBitmask::getAll ();
315
+ }
316
+
317
+ return Result;
318
+ }
319
+
320
+ const LiveRange *LR = LIS.getCachedRegUnit (RegUnit);
321
+ if (LR == nullptr )
322
+ return SafeDefault;
323
+ return Property (*LR, Pos) ? LaneBitmask::getAll () : LaneBitmask::getNone ();
324
+ }
325
+
326
+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
327
+ // / Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
328
+ // / The query starts with a lane bitmask which gets lanes/bits removed for every
329
+ // / use we find.
330
+ static LaneBitmask findUseBetween (unsigned Reg, LaneBitmask LastUseMask,
331
+ SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
332
+ const MachineRegisterInfo &MRI,
333
+ const SIRegisterInfo *TRI,
334
+ const LiveIntervals *LIS,
335
+ bool Upward = false ) {
336
+ for (const MachineOperand &MO : MRI.use_nodbg_operands (Reg)) {
337
+ if (MO.isUndef ())
338
+ continue ;
339
+ const MachineInstr *MI = MO.getParent ();
340
+ SlotIndex InstSlot = LIS->getInstructionIndex (*MI).getRegSlot ();
341
+ bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
342
+ : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
343
+ if (!InRange)
344
+ continue ;
345
+
346
+ unsigned SubRegIdx = MO.getSubReg ();
347
+ LaneBitmask UseMask = TRI->getSubRegIndexLaneMask (SubRegIdx);
348
+ LastUseMask &= ~UseMask;
349
+ if (LastUseMask.none ())
350
+ return LaneBitmask::getNone ();
351
+ }
352
+ return LastUseMask;
353
+ }
354
+
298
355
// /////////////////////////////////////////////////////////////////////////////
299
356
// GCNRPTracker
300
357
@@ -353,17 +410,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
353
410
MaxPressure = CurPressure = getRegPressure (*MRI, LiveRegs);
354
411
}
355
412
356
- // //////////////////////////////////////////////////////////////////////////////
357
- // GCNUpwardRPTracker
358
-
359
- void GCNUpwardRPTracker::reset (const MachineRegisterInfo &MRI_,
360
- const LiveRegSet &LiveRegs_) {
413
+ void GCNRPTracker::reset (const MachineRegisterInfo &MRI_,
414
+ const LiveRegSet &LiveRegs_) {
361
415
MRI = &MRI_;
362
416
LiveRegs = LiveRegs_;
363
417
LastTrackedMI = nullptr ;
364
418
MaxPressure = CurPressure = getRegPressure (MRI_, LiveRegs_);
365
419
}
366
420
421
+ // / Mostly copy/paste from CodeGen/RegisterPressure.cpp
422
+ LaneBitmask GCNRPTracker::getLastUsedLanes (Register RegUnit,
423
+ SlotIndex Pos) const {
424
+ return getLanesWithProperty (
425
+ LIS, *MRI, true , RegUnit, Pos.getBaseIndex (), LaneBitmask::getNone (),
426
+ [](const LiveRange &LR, SlotIndex Pos) {
427
+ const LiveRange::Segment *S = LR.getSegmentContaining (Pos);
428
+ return S != nullptr && S->end == Pos.getRegSlot ();
429
+ });
430
+ }
431
+
432
+ // //////////////////////////////////////////////////////////////////////////////
433
+ // GCNUpwardRPTracker
434
+
367
435
void GCNUpwardRPTracker::recede (const MachineInstr &MI) {
368
436
assert (MRI && " call reset first" );
369
437
@@ -440,25 +508,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
440
508
return true ;
441
509
}
442
510
443
- bool GCNDownwardRPTracker::advanceBeforeNext () {
511
+ bool GCNDownwardRPTracker::advanceBeforeNext (MachineInstr *MI,
512
+ bool UseInternalIterator) {
444
513
assert (MRI && " call reset first" );
445
- if (!LastTrackedMI)
446
- return NextMI == MBBEnd;
447
-
448
- assert (NextMI == MBBEnd || !NextMI->isDebugInstr ());
514
+ SlotIndex SI;
515
+ const MachineInstr *CurrMI;
516
+ if (UseInternalIterator) {
517
+ if (!LastTrackedMI)
518
+ return NextMI == MBBEnd;
519
+
520
+ assert (NextMI == MBBEnd || !NextMI->isDebugInstr ());
521
+ CurrMI = LastTrackedMI;
522
+
523
+ SI = NextMI == MBBEnd
524
+ ? LIS.getInstructionIndex (*LastTrackedMI).getDeadSlot ()
525
+ : LIS.getInstructionIndex (*NextMI).getBaseIndex ();
526
+ } else { // ! UseInternalIterator
527
+ SI = LIS.getInstructionIndex (*MI).getBaseIndex ();
528
+ CurrMI = MI;
529
+ }
449
530
450
- SlotIndex SI = NextMI == MBBEnd
451
- ? LIS.getInstructionIndex (*LastTrackedMI).getDeadSlot ()
452
- : LIS.getInstructionIndex (*NextMI).getBaseIndex ();
453
531
assert (SI.isValid ());
454
532
455
533
// Remove dead registers or mask bits.
456
534
SmallSet<Register, 8 > SeenRegs;
457
- for (auto &MO : LastTrackedMI ->operands ()) {
535
+ for (auto &MO : CurrMI ->operands ()) {
458
536
if (!MO.isReg () || !MO.getReg ().isVirtual ())
459
537
continue ;
460
538
if (MO.isUse () && !MO.readsReg ())
461
539
continue ;
540
+ if (!UseInternalIterator && MO.isDef ())
541
+ continue ;
462
542
if (!SeenRegs.insert (MO.getReg ()).second )
463
543
continue ;
464
544
const LiveInterval &LI = LIS.getInterval (MO.getReg ());
@@ -491,15 +571,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {
491
571
492
572
LastTrackedMI = nullptr ;
493
573
494
- return NextMI == MBBEnd;
574
+ return UseInternalIterator && ( NextMI == MBBEnd) ;
495
575
}
496
576
497
- void GCNDownwardRPTracker::advanceToNext () {
498
- LastTrackedMI = &*NextMI++;
499
- NextMI = skipDebugInstructionsForward (NextMI, MBBEnd);
577
+ void GCNDownwardRPTracker::advanceToNext (MachineInstr *MI,
578
+ bool UseInternalIterator) {
579
+ if (UseInternalIterator) {
580
+ LastTrackedMI = &*NextMI++;
581
+ NextMI = skipDebugInstructionsForward (NextMI, MBBEnd);
582
+ } else {
583
+ LastTrackedMI = MI;
584
+ }
585
+
586
+ const MachineInstr *CurrMI = LastTrackedMI;
500
587
501
588
// Add new registers or mask bits.
502
- for (const auto &MO : LastTrackedMI ->all_defs ()) {
589
+ for (const auto &MO : CurrMI ->all_defs ()) {
503
590
Register Reg = MO.getReg ();
504
591
if (!Reg.isVirtual ())
505
592
continue ;
@@ -512,11 +599,16 @@ void GCNDownwardRPTracker::advanceToNext() {
512
599
MaxPressure = max (MaxPressure, CurPressure);
513
600
}
514
601
515
- bool GCNDownwardRPTracker::advance () {
516
- if (NextMI == MBBEnd)
602
+ bool GCNDownwardRPTracker::advance (MachineInstr *MI, bool UseInternalIterator ) {
603
+ if (UseInternalIterator && NextMI == MBBEnd)
517
604
return false ;
518
- advanceBeforeNext ();
519
- advanceToNext ();
605
+
606
+ advanceBeforeNext (MI, UseInternalIterator);
607
+ advanceToNext (MI, UseInternalIterator);
608
+ if (!UseInternalIterator) {
609
+ // We must remove any dead def lanes from the current RP
610
+ advanceBeforeNext (MI, true );
611
+ }
520
612
return true ;
521
613
}
522
614
@@ -558,6 +650,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
558
650
});
559
651
}
560
652
653
+ GCNRegPressure
654
+ GCNDownwardRPTracker::bumpDownwardPressure (const MachineInstr *MI,
655
+ const SIRegisterInfo *TRI) const {
656
+ assert (!MI->isDebugOrPseudoInstr () && " Expect a nondebug instruction." );
657
+
658
+ SlotIndex SlotIdx;
659
+ SlotIdx = LIS.getInstructionIndex (*MI).getRegSlot ();
660
+
661
+ // Account for register pressure similar to RegPressureTracker::recede().
662
+ RegisterOperands RegOpers;
663
+ RegOpers.collect (*MI, *TRI, *MRI, true , /* IgnoreDead=*/ false );
664
+ RegOpers.adjustLaneLiveness (LIS, *MRI, SlotIdx);
665
+ GCNRegPressure TempPressure = CurPressure;
666
+
667
+ for (const RegisterMaskPair &Use : RegOpers.Uses ) {
668
+ Register Reg = Use.RegUnit ;
669
+ if (!Reg.isVirtual ())
670
+ continue ;
671
+ LaneBitmask LastUseMask = getLastUsedLanes (Reg, SlotIdx);
672
+ if (LastUseMask.none ())
673
+ continue ;
674
+ // The LastUseMask is queried from the liveness information of instruction
675
+ // which may be further down the schedule. Some lanes may actually not be
676
+ // last uses for the current position.
677
+ // FIXME: allow the caller to pass in the list of vreg uses that remain
678
+ // to be bottom-scheduled to avoid searching uses at each query.
679
+ SlotIndex CurrIdx;
680
+ const MachineBasicBlock *MBB = MI->getParent ();
681
+ MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward (
682
+ LastTrackedMI ? LastTrackedMI : MBB->begin (), MBB->end ());
683
+ if (IdxPos == MBB->end ()) {
684
+ CurrIdx = LIS.getMBBEndIdx (MBB);
685
+ } else {
686
+ CurrIdx = LIS.getInstructionIndex (*IdxPos).getRegSlot ();
687
+ }
688
+
689
+ LastUseMask =
690
+ findUseBetween (Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
691
+ if (LastUseMask.none ())
692
+ continue ;
693
+
694
+ LaneBitmask LiveMask =
695
+ LiveRegs.contains (Reg) ? LiveRegs.at (Reg) : LaneBitmask (0 );
696
+ LaneBitmask NewMask = LiveMask & ~LastUseMask;
697
+ TempPressure.inc (Reg, LiveMask, NewMask, *MRI);
698
+ }
699
+
700
+ // Generate liveness for defs.
701
+ for (const RegisterMaskPair &Def : RegOpers.Defs ) {
702
+ Register Reg = Def.RegUnit ;
703
+ if (!Reg.isVirtual ())
704
+ continue ;
705
+ LaneBitmask LiveMask =
706
+ LiveRegs.contains (Reg) ? LiveRegs.at (Reg) : LaneBitmask (0 );
707
+ LaneBitmask NewMask = LiveMask | Def.LaneMask ;
708
+ TempPressure.inc (Reg, LiveMask, NewMask, *MRI);
709
+ }
710
+
711
+ return TempPressure;
712
+ }
713
+
561
714
bool GCNUpwardRPTracker::isValid () const {
562
715
const auto &SI = LIS.getInstructionIndex (*LastTrackedMI).getBaseIndex ();
563
716
const auto LISLR = llvm::getLiveRegs (SI, LIS, *MRI);
0 commit comments