20
20
#include " Utils/AArch64BaseInfo.h"
21
21
#include " llvm/ADT/ArrayRef.h"
22
22
#include " llvm/ADT/STLExtras.h"
23
+ #include " llvm/ADT/SmallSet.h"
23
24
#include " llvm/ADT/SmallVector.h"
24
25
#include " llvm/CodeGen/LivePhysRegs.h"
25
26
#include " llvm/CodeGen/MachineBasicBlock.h"
@@ -78,6 +79,18 @@ static cl::opt<unsigned>
78
79
BDisplacementBits (" aarch64-b-offset-bits" , cl::Hidden, cl::init(26 ),
79
80
cl::desc(" Restrict range of B instructions (DEBUG)" ));
80
81
82
+ static cl::opt<bool >
83
+ EnableAccReassociation (" aarch64-acc-reassoc" , cl::Hidden, cl::init(true ),
84
+ cl::desc(" Enable reassociation of accumulation chains" ));
85
+
86
+ static cl::opt<unsigned int >
87
+ MinAccumulatorDepth (" aarch64-acc-min-depth" , cl::Hidden, cl::init(8 ),
88
+ cl::desc(" Minimum length of accumulator chains required for the optimization to kick in" ));
89
+
90
+ static cl::opt<unsigned int >
91
+ MaxAccumulatorWidth (" aarch64-acc-max-width" , cl::Hidden, cl::init(3 ), cl::desc(" Maximum number of branches in the accumulator tree" ));
92
+
93
+
81
94
AArch64InstrInfo::AArch64InstrInfo (const AArch64Subtarget &STI)
82
95
: AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
83
96
AArch64::CATCHRET),
@@ -6674,6 +6687,118 @@ static bool getMaddPatterns(MachineInstr &Root,
6674
6687
}
6675
6688
return Found;
6676
6689
}
6690
+
6691
+ static bool isAccumulationOpcode (unsigned Opcode) {
6692
+ switch (Opcode) {
6693
+ default :
6694
+ break ;
6695
+ case AArch64::UABALB_ZZZ_D:
6696
+ case AArch64::UABALB_ZZZ_H:
6697
+ case AArch64::UABALB_ZZZ_S:
6698
+ case AArch64::UABALT_ZZZ_D:
6699
+ case AArch64::UABALT_ZZZ_H:
6700
+ case AArch64::UABALT_ZZZ_S:
6701
+ case AArch64::UABALv16i8_v8i16:
6702
+ case AArch64::UABALv2i32_v2i64:
6703
+ case AArch64::UABALv4i16_v4i32:
6704
+ case AArch64::UABALv4i32_v2i64:
6705
+ case AArch64::UABALv8i16_v4i32:
6706
+ case AArch64::UABALv8i8_v8i16:
6707
+ return true ;
6708
+ }
6709
+
6710
+ return false ;
6711
+ }
6712
+
6713
+ static unsigned getAccumulationStartOpCode (unsigned AccumulationOpcode) {
6714
+ switch (AccumulationOpcode) {
6715
+ default :
6716
+ llvm_unreachable (" Unknown accumulator opcode" );
6717
+ case AArch64::UABALB_ZZZ_D:
6718
+ return AArch64::UABDLB_ZZZ_D;
6719
+ case AArch64::UABALB_ZZZ_H:
6720
+ return AArch64::UABDLB_ZZZ_H;
6721
+ case AArch64::UABALB_ZZZ_S:
6722
+ return AArch64::UABDLB_ZZZ_S;
6723
+ case AArch64::UABALT_ZZZ_D:
6724
+ return AArch64::UABDLT_ZZZ_D;
6725
+ case AArch64::UABALT_ZZZ_H:
6726
+ return AArch64::UABDLT_ZZZ_H;
6727
+ case AArch64::UABALT_ZZZ_S:
6728
+ return AArch64::UABDLT_ZZZ_S;
6729
+ case AArch64::UABALv16i8_v8i16:
6730
+ return AArch64::UABDLv16i8_v8i16;
6731
+ case AArch64::UABALv2i32_v2i64:
6732
+ return AArch64::UABDLv2i32_v2i64;
6733
+ case AArch64::UABALv4i16_v4i32:
6734
+ return AArch64::UABDLv4i16_v4i32;
6735
+ case AArch64::UABALv4i32_v2i64:
6736
+ return AArch64::UABDLv4i32_v2i64;
6737
+ case AArch64::UABALv8i16_v4i32:
6738
+ return AArch64::UABDLv8i16_v4i32;
6739
+ case AArch64::UABALv8i8_v8i16:
6740
+ return AArch64::UABDLv8i8_v8i16;
6741
+ }
6742
+ }
6743
+
6744
+ static void getAccumulatorChain (MachineInstr *CurrentInstr, MachineBasicBlock &MBB, MachineRegisterInfo &MRI, SmallVectorImpl<Register> &Chain) {
6745
+ // Walk up the chain of accumulation instructions and collect them in the vector.
6746
+ unsigned AccumulatorOpcode = CurrentInstr->getOpcode ();
6747
+ unsigned ChainStartOpCode = getAccumulationStartOpCode (AccumulatorOpcode);
6748
+ while (CurrentInstr && (canCombine (MBB, CurrentInstr->getOperand (1 ), AccumulatorOpcode) || canCombine (MBB, CurrentInstr->getOperand (1 ), ChainStartOpCode))) {
6749
+ Chain.push_back (CurrentInstr->getOperand (0 ).getReg ());
6750
+ CurrentInstr = MRI.getUniqueVRegDef (CurrentInstr->getOperand (1 ).getReg ());
6751
+ }
6752
+
6753
+ // Add the instruction at the top of the chain.
6754
+ if (CurrentInstr->getOpcode () == ChainStartOpCode)
6755
+ Chain.push_back (CurrentInstr->getOperand (0 ).getReg ());
6756
+ }
6757
+
6758
+ // / Find chains of accumulations, likely linearized by reassocation pass,
6759
+ // / that can be rewritten as a tree for increased ILP.
6760
+ static bool getAccumulatorReassociationPatterns (MachineInstr &Root,
6761
+ SmallVectorImpl<unsigned > &Patterns) {
6762
+ // find a chain of depth 4, which would make it profitable to rewrite
6763
+ // as a tree. This pattern should be applied recursively in case we
6764
+ // have a longer chain.
6765
+ if (!EnableAccReassociation)
6766
+ return false ;
6767
+
6768
+ unsigned Opc = Root.getOpcode ();
6769
+ if (!isAccumulationOpcode (Opc))
6770
+ return false ;
6771
+
6772
+ // Verify that this is the end of the chain.
6773
+ MachineBasicBlock &MBB = *Root.getParent ();
6774
+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
6775
+ if (!MRI.hasOneNonDBGUser (Root.getOperand (0 ).getReg ()))
6776
+ return false ;
6777
+
6778
+ auto User = MRI.use_instr_begin (Root.getOperand (0 ).getReg ());
6779
+ if (User->getOpcode () == Opc)
6780
+ return false ;
6781
+
6782
+ // Walk up the use chain and collect the reduction chain.
6783
+ SmallVector<Register, 32 > Chain;
6784
+ getAccumulatorChain (&Root, MBB, MRI, Chain);
6785
+
6786
+ // Reject chains which are too short to be worth modifying.
6787
+ if (Chain.size () < MinAccumulatorDepth)
6788
+ return false ;
6789
+
6790
+ // Check if the MBB this instruction is a part of contains any other chains. If so, don't apply it.
6791
+ SmallSet<Register, 32 > ReductionChain (Chain.begin (), Chain.end ());
6792
+ for (const auto &I : MBB) {
6793
+ if (I.getOpcode () == Opc && !ReductionChain.contains (I.getOperand (0 ).getReg ()))
6794
+ return false ;
6795
+ }
6796
+
6797
+ typedef AArch64MachineCombinerPattern MCP;
6798
+ Patterns.push_back (MCP::ACC_CHAIN);
6799
+ return true ;
6800
+ }
6801
+
6677
6802
// / Floating-Point Support
6678
6803
6679
6804
// / Find instructions that can be turned into madd.
@@ -7061,6 +7186,7 @@ AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
7061
7186
switch (Pattern) {
7062
7187
case AArch64MachineCombinerPattern::SUBADD_OP1:
7063
7188
case AArch64MachineCombinerPattern::SUBADD_OP2:
7189
+ case AArch64MachineCombinerPattern::ACC_CHAIN:
7064
7190
return CombinerObjective::MustReduceDepth;
7065
7191
default :
7066
7192
return TargetInstrInfo::getCombinerObjective (Pattern);
@@ -7078,6 +7204,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
7078
7204
// Integer patterns
7079
7205
if (getMaddPatterns (Root, Patterns))
7080
7206
return true ;
7207
+ if (getAccumulatorReassociationPatterns (Root, Patterns))
7208
+ return true ;
7081
7209
// Floating point patterns
7082
7210
if (getFMULPatterns (Root, Patterns))
7083
7211
return true ;
@@ -7436,6 +7564,72 @@ genSubAdd2SubSub(MachineFunction &MF, MachineRegisterInfo &MRI,
7436
7564
DelInstrs.push_back (&Root);
7437
7565
}
7438
7566
7567
+ static unsigned int getReduceOpCodeForAccumulator (unsigned int AccumulatorOpCode) {
7568
+ switch (AccumulatorOpCode) {
7569
+ case AArch64::UABALB_ZZZ_D:
7570
+ return AArch64::ADD_ZZZ_D;
7571
+ case AArch64::UABALB_ZZZ_H:
7572
+ return AArch64::ADD_ZZZ_H;
7573
+ case AArch64::UABALB_ZZZ_S:
7574
+ return AArch64::ADD_ZZZ_S;
7575
+ case AArch64::UABALT_ZZZ_D:
7576
+ return AArch64::ADD_ZZZ_D;
7577
+ case AArch64::UABALT_ZZZ_H:
7578
+ return AArch64::ADD_ZZZ_H;
7579
+ case AArch64::UABALT_ZZZ_S:
7580
+ return AArch64::ADD_ZZZ_S;
7581
+ case AArch64::UABALv16i8_v8i16:
7582
+ return AArch64::ADDv8i16;
7583
+ case AArch64::UABALv2i32_v2i64:
7584
+ return AArch64::ADDv2i64;
7585
+ case AArch64::UABALv4i16_v4i32:
7586
+ return AArch64::ADDv4i32;
7587
+ case AArch64::UABALv4i32_v2i64:
7588
+ return AArch64::ADDv2i64;
7589
+ case AArch64::UABALv8i16_v4i32:
7590
+ return AArch64::ADDv4i32;
7591
+ case AArch64::UABALv8i8_v8i16:
7592
+ return AArch64::ADDv8i16;
7593
+ default :
7594
+ llvm_unreachable (" Unknown accumulator opcode" );
7595
+ }
7596
+ }
7597
+
7598
+ // Reduce branches of the accumulator tree by adding them together.
7599
+ static void reduceAccumulatorTree (SmallVectorImpl<Register> &RegistersToReduce, SmallVectorImpl<MachineInstr *> &InsInstrs,
7600
+ MachineFunction &MF, MachineInstr &Root, MachineRegisterInfo &MRI,
7601
+ DenseMap<unsigned , unsigned > &InstrIdxForVirtReg, Register ResultReg) {
7602
+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
7603
+ SmallVector<Register, 8 > NewRegs;
7604
+ for (unsigned int i = 1 ; i <= (RegistersToReduce.size () / 2 ); i+=2 ) {
7605
+ auto RHS = RegistersToReduce[i - 1 ];
7606
+ auto LHS = RegistersToReduce[i];
7607
+ Register Dest;
7608
+ // If we are reducing 2 registers, reuse the original result register.
7609
+ if (RegistersToReduce.size () == 2 )
7610
+ Dest = ResultReg;
7611
+ // Otherwise, create a new virtual register to hold the partial sum.
7612
+ else {
7613
+ auto NewVR = MRI.createVirtualRegister (MRI.getRegClass (Root.getOperand (0 ).getReg ()));
7614
+ Dest = NewVR;
7615
+ NewRegs.push_back (Dest);
7616
+ InstrIdxForVirtReg.insert (std::make_pair (Dest, InsInstrs.size ()));
7617
+ }
7618
+
7619
+ // Create the new add instruction.
7620
+ MachineInstrBuilder MIB = BuildMI (MF, MIMetadata (Root), TII->get (getReduceOpCodeForAccumulator (Root.getOpcode ())), Dest).addReg (RHS, getKillRegState (true )).addReg (LHS, getKillRegState (true ));
7621
+ // Copy any flags needed from the original instruction.
7622
+ MIB->setFlags (Root.getFlags ());
7623
+ InsInstrs.push_back (MIB);
7624
+ }
7625
+
7626
+ // If the number of registers to reduce is odd, add the reminaing register to the vector of registers to reduce.
7627
+ if (RegistersToReduce.size () % 2 != 0 )
7628
+ NewRegs.push_back (RegistersToReduce[RegistersToReduce.size () - 1 ]);
7629
+
7630
+ RegistersToReduce = NewRegs;
7631
+ }
7632
+
7439
7633
// / When getMachineCombinerPatterns() finds potential patterns,
7440
7634
// / this function generates the instructions that could replace the
7441
7635
// / original code sequence
@@ -7671,7 +7865,53 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
7671
7865
MUL = genMaddR (MF, MRI, TII, Root, InsInstrs, 1 , Opc, NewVR, RC);
7672
7866
break ;
7673
7867
}
7868
+ case AArch64MachineCombinerPattern::ACC_CHAIN: {
7869
+ SmallVector<Register, 32 > ChainRegs;
7870
+ getAccumulatorChain (&Root, MBB, MRI, ChainRegs);
7871
+
7872
+ unsigned int Depth = ChainRegs.size ();
7873
+ assert (MaxAccumulatorWidth > 1 && " Max accumulator width set to illegal value" );
7874
+ unsigned int MaxWidth = Log2_32 (Depth) < MaxAccumulatorWidth ? Log2_32 (Depth) : MaxAccumulatorWidth;
7875
+
7876
+ // Walk down the chain and rewrite it as a tree.
7877
+ for (auto IndexedReg : llvm::enumerate (llvm::reverse (ChainRegs))) {
7878
+ // No need to rewrite the first node, it is already perfect as it is.
7879
+ if (IndexedReg.index () == 0 )
7880
+ continue ;
7881
+
7882
+ MachineInstr *Instr = MRI.getUniqueVRegDef (IndexedReg.value ());
7883
+ MachineInstrBuilder MIB;
7884
+ Register AccReg;
7885
+ if (IndexedReg.index () < MaxWidth) {
7886
+ // Now we need to create new instructions for the first row.
7887
+ AccReg = Instr->getOperand (0 ).getReg ();
7888
+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (MRI.getUniqueVRegDef (ChainRegs.back ())->getOpcode ()), AccReg).addReg (Instr->getOperand (2 ).getReg (), getKillRegState (Instr->getOperand (2 ).isKill ())).addReg (Instr->getOperand (3 ).getReg (), getKillRegState (Instr->getOperand (3 ).isKill ()));
7889
+ } else {
7890
+ // For the remaining cases, we need ot use an output register of one of the newly inserted instuctions as operand 1
7891
+ AccReg = Instr->getOperand (0 ).getReg () == Root.getOperand (0 ).getReg () ? MRI.createVirtualRegister (MRI.getRegClass (Root.getOperand (0 ).getReg ())) : Instr->getOperand (0 ).getReg ();
7892
+ assert (IndexedReg.index () - MaxWidth >= 0 );
7893
+ auto AccumulatorInput = ChainRegs[Depth - (IndexedReg.index () - MaxWidth) - 1 ];
7894
+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (Instr->getOpcode ()), AccReg).addReg (AccumulatorInput, getKillRegState (true )).addReg (Instr->getOperand (2 ).getReg (), getKillRegState (Instr->getOperand (2 ).isKill ())).addReg (Instr->getOperand (3 ).getReg (), getKillRegState (Instr->getOperand (3 ).isKill ()));
7895
+ }
7896
+
7897
+ MIB->setFlags (Instr->getFlags ());
7898
+ InstrIdxForVirtReg.insert (std::make_pair (AccReg, InsInstrs.size ()));
7899
+ InsInstrs.push_back (MIB);
7900
+ DelInstrs.push_back (Instr);
7901
+ }
7902
+
7903
+ SmallVector<Register, 8 > RegistersToReduce;
7904
+ for (int i = (InsInstrs.size () - MaxWidth); i < InsInstrs.size (); ++i) {
7905
+ auto Reg = InsInstrs[i]->getOperand (0 ).getReg ();
7906
+ RegistersToReduce.push_back (Reg);
7907
+ }
7674
7908
7909
+ while (RegistersToReduce.size () > 1 )
7910
+ reduceAccumulatorTree (RegistersToReduce, InsInstrs, MF, Root, MRI, InstrIdxForVirtReg, Root.getOperand (0 ).getReg ());
7911
+
7912
+ // We don't want to break, we handle setting flags and adding Root to DelInstrs from here.
7913
+ return ;
7914
+ }
7675
7915
case AArch64MachineCombinerPattern::MULADDv8i8_OP1:
7676
7916
Opc = AArch64::MLAv8i8;
7677
7917
RC = &AArch64::FPR64RegClass;
0 commit comments