 #include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -78,6 +79,19 @@ static cl::opt<unsigned>
    BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
                      cl::desc("Restrict range of B instructions (DEBUG)"));

+static cl::opt<bool> EnableAccReassociation(
+    "aarch64-acc-reassoc", cl::Hidden, cl::init(true),
+    cl::desc("Enable reassociation of accumulation chains"));
+
+static cl::opt<unsigned int>
+    MinAccumulatorDepth("aarch64-acc-min-depth", cl::Hidden, cl::init(8),
+                        cl::desc("Minimum length of accumulator chains "
+                                 "required for the optimization to kick in"));
+
+static cl::opt<unsigned int> MaxAccumulatorWidth(
+    "aarch64-acc-max-width", cl::Hidden, cl::init(3),
+    cl::desc("Maximum number of branches in the accumulator tree"));
+
AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
    : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
                          AArch64::CATCHRET),
@@ -6674,6 +6688,127 @@ static bool getMaddPatterns(MachineInstr &Root,
  }
  return Found;
}
+
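+// At the moment only the unsigned absolute-difference-and-accumulate
+// (UABAL*) opcodes, for both NEON and SVE, are treated as accumulation
+// instructions.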
+static bool isAccumulationOpcode(unsigned Opcode) {
+  switch (Opcode) {
+  default:
+    break;
+  case AArch64::UABALB_ZZZ_D:
+  case AArch64::UABALB_ZZZ_H:
+  case AArch64::UABALB_ZZZ_S:
+  case AArch64::UABALT_ZZZ_D:
+  case AArch64::UABALT_ZZZ_H:
+  case AArch64::UABALT_ZZZ_S:
+  case AArch64::UABALv16i8_v8i16:
+  case AArch64::UABALv2i32_v2i64:
+  case AArch64::UABALv4i16_v4i32:
+  case AArch64::UABALv4i32_v2i64:
+  case AArch64::UABALv8i16_v4i32:
+  case AArch64::UABALv8i8_v8i16:
+    return true;
+  }
+
+  return false;
+}
+
+static unsigned getAccumulationStartOpCode(unsigned AccumulationOpcode) {
+  switch (AccumulationOpcode) {
+  default:
+    llvm_unreachable("Unknown accumulator opcode");
+  case AArch64::UABALB_ZZZ_D:
+    return AArch64::UABDLB_ZZZ_D;
+  case AArch64::UABALB_ZZZ_H:
+    return AArch64::UABDLB_ZZZ_H;
+  case AArch64::UABALB_ZZZ_S:
+    return AArch64::UABDLB_ZZZ_S;
+  case AArch64::UABALT_ZZZ_D:
+    return AArch64::UABDLT_ZZZ_D;
+  case AArch64::UABALT_ZZZ_H:
+    return AArch64::UABDLT_ZZZ_H;
+  case AArch64::UABALT_ZZZ_S:
+    return AArch64::UABDLT_ZZZ_S;
+  case AArch64::UABALv16i8_v8i16:
+    return AArch64::UABDLv16i8_v8i16;
+  case AArch64::UABALv2i32_v2i64:
+    return AArch64::UABDLv2i32_v2i64;
+  case AArch64::UABALv4i16_v4i32:
+    return AArch64::UABDLv4i16_v4i32;
+  case AArch64::UABALv4i32_v2i64:
+    return AArch64::UABDLv4i32_v2i64;
+  case AArch64::UABALv8i16_v4i32:
+    return AArch64::UABDLv8i16_v4i32;
+  case AArch64::UABALv8i8_v8i16:
+    return AArch64::UABDLv8i8_v8i16;
+  }
+}
+
+static void getAccumulatorChain(MachineInstr *CurrentInstr,
+                                MachineBasicBlock &MBB,
+                                MachineRegisterInfo &MRI,
+                                SmallVectorImpl<Register> &Chain) {
+  // Walk up the chain of accumulation instructions and collect them in the
+  // vector.
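+  // Registers are collected bottom-up: the first entry belongs to the end of
+  // the chain and the last one to its start.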
+  unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
+  unsigned ChainStartOpCode = getAccumulationStartOpCode(AccumulatorOpcode);
+  while (CurrentInstr &&
+         (canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode) ||
+          canCombine(MBB, CurrentInstr->getOperand(1), ChainStartOpCode))) {
+    Chain.push_back(CurrentInstr->getOperand(0).getReg());
+    CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg());
+  }
+
+  // Add the instruction at the top of the chain.
+  if (CurrentInstr->getOpcode() == ChainStartOpCode)
+    Chain.push_back(CurrentInstr->getOperand(0).getReg());
+}
+
+/// Find chains of accumulations, likely linearized by the reassociation pass,
+/// that can be rewritten as a tree for increased ILP.
+static bool
+getAccumulatorReassociationPatterns(MachineInstr &Root,
+                                    SmallVectorImpl<unsigned> &Patterns) {
+  // Find a chain long enough that rewriting it as a tree is profitable. This
+  // pattern should be applied recursively in case we have a longer chain.
+  if (!EnableAccReassociation)
+    return false;
+
+  unsigned Opc = Root.getOpcode();
+  if (!isAccumulationOpcode(Opc))
+    return false;
+
+  // Verify that this is the end of the chain.
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
+    return false;
+
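+  // If the single user is itself an accumulation of the same kind, Root sits
+  // in the middle of a chain rather than at its end.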
+  auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
+  if (User->getOpcode() == Opc)
+    return false;
+
+  // Walk up the use chain and collect the reduction chain.
+  SmallVector<Register, 32> Chain;
+  getAccumulatorChain(&Root, MBB, MRI, Chain);
+
+  // Reject chains which are too short to be worth modifying.
+  if (Chain.size() < MinAccumulatorDepth)
+    return false;
+
+  // Check if the MBB this instruction is a part of contains any other chains.
+  // If so, don't apply it.
+  SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end());
+  for (const auto &I : MBB) {
+    if (I.getOpcode() == Opc &&
+        !ReductionChain.contains(I.getOperand(0).getReg()))
+      return false;
+  }
+
+  typedef AArch64MachineCombinerPattern MCP;
+  Patterns.push_back(MCP::ACC_CHAIN);
+  return true;
+}
+
/// Floating-Point Support

/// Find instructions that can be turned into madd.
@@ -7061,6 +7196,7 @@ AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
  switch (Pattern) {
  case AArch64MachineCombinerPattern::SUBADD_OP1:
  case AArch64MachineCombinerPattern::SUBADD_OP2:
+  case AArch64MachineCombinerPattern::ACC_CHAIN:
    return CombinerObjective::MustReduceDepth;
  default:
    return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7078,6 +7214,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
  // Integer patterns
  if (getMaddPatterns(Root, Patterns))
    return true;
+  if (getAccumulatorReassociationPatterns(Root, Patterns))
+    return true;
  // Floating point patterns
  if (getFMULPatterns(Root, Patterns))
    return true;
@@ -7436,6 +7574,81 @@ genSubAdd2SubSub(MachineFunction &MF, MachineRegisterInfo &MRI,
  DelInstrs.push_back(&Root);
}

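+// Partial accumulators of the tree are summed with a plain ADD of the
+// matching vector or SVE element type.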
+static unsigned int
+getReduceOpCodeForAccumulator(unsigned int AccumulatorOpCode) {
+  switch (AccumulatorOpCode) {
+  case AArch64::UABALB_ZZZ_D:
+    return AArch64::ADD_ZZZ_D;
+  case AArch64::UABALB_ZZZ_H:
+    return AArch64::ADD_ZZZ_H;
+  case AArch64::UABALB_ZZZ_S:
+    return AArch64::ADD_ZZZ_S;
+  case AArch64::UABALT_ZZZ_D:
+    return AArch64::ADD_ZZZ_D;
+  case AArch64::UABALT_ZZZ_H:
+    return AArch64::ADD_ZZZ_H;
+  case AArch64::UABALT_ZZZ_S:
+    return AArch64::ADD_ZZZ_S;
+  case AArch64::UABALv16i8_v8i16:
+    return AArch64::ADDv8i16;
+  case AArch64::UABALv2i32_v2i64:
+    return AArch64::ADDv2i64;
+  case AArch64::UABALv4i16_v4i32:
+    return AArch64::ADDv4i32;
+  case AArch64::UABALv4i32_v2i64:
+    return AArch64::ADDv2i64;
+  case AArch64::UABALv8i16_v4i32:
+    return AArch64::ADDv4i32;
+  case AArch64::UABALv8i8_v8i16:
+    return AArch64::ADDv8i16;
+  default:
+    llvm_unreachable("Unknown accumulator opcode");
+  }
+}
+
+// Reduce branches of the accumulator tree by adding them together.
+static void reduceAccumulatorTree(
+    SmallVectorImpl<Register> &RegistersToReduce,
+    SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
+    MachineInstr &Root, MachineRegisterInfo &MRI,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, Register ResultReg) {
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+  SmallVector<Register, 8> NewRegs;
+  for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) {
+    auto RHS = RegistersToReduce[i - 1];
+    auto LHS = RegistersToReduce[i];
+    Register Dest;
+    // If we are reducing 2 registers, reuse the original result register.
+    if (RegistersToReduce.size() == 2)
+      Dest = ResultReg;
+    // Otherwise, create a new virtual register to hold the partial sum.
+    else {
+      auto NewVR = MRI.createVirtualRegister(
+          MRI.getRegClass(Root.getOperand(0).getReg()));
+      Dest = NewVR;
+      NewRegs.push_back(Dest);
+      InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
+    }
+
+    // Create the new add instruction.
+    MachineInstrBuilder MIB =
+        BuildMI(MF, MIMetadata(Root),
+                TII->get(getReduceOpCodeForAccumulator(Root.getOpcode())),
+                Dest)
+            .addReg(RHS, getKillRegState(true))
+            .addReg(LHS, getKillRegState(true));
+    // Copy any flags needed from the original instruction.
+    MIB->setFlags(Root.getFlags());
+    InsInstrs.push_back(MIB);
+  }
+
+  // If the number of registers to reduce is odd, add the remaining register
+  // to the vector of registers to reduce.
+  if (RegistersToReduce.size() % 2 != 0)
+    NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
+
+  RegistersToReduce = NewRegs;
+}
+

/// When getMachineCombinerPatterns() finds potential patterns,
/// this function generates the instructions that could replace the
/// original code sequence
@@ -7671,7 +7884,76 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
    MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
    break;
  }
+  case AArch64MachineCombinerPattern::ACC_CHAIN: {
+    SmallVector<Register, 32> ChainRegs;
+    getAccumulatorChain(&Root, MBB, MRI, ChainRegs);
+
+    unsigned int Depth = ChainRegs.size();
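+    // The chain is rewritten into at most MaxWidth parallel accumulator
+    // branches: the aarch64-acc-max-width setting, capped at log2 of the
+    // chain depth for short chains.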
+    assert(MaxAccumulatorWidth > 1 &&
+           "Max accumulator width set to illegal value");
+    unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth
+                                ? Log2_32(Depth)
+                                : MaxAccumulatorWidth;
+
+    // Walk down the chain and rewrite it as a tree.
+    for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
+      // No need to rewrite the first node, it is already perfect as it is.
+      if (IndexedReg.index() == 0)
+        continue;
+
+      MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value());
+      MachineInstrBuilder MIB;
+      Register AccReg;
+      if (IndexedReg.index() < MaxWidth) {
+        // Now we need to create new instructions for the first row.
+        AccReg = Instr->getOperand(0).getReg();
+        MIB = BuildMI(
+                  MF, MIMetadata(*Instr),
+                  TII->get(MRI.getUniqueVRegDef(ChainRegs.back())->getOpcode()),
+                  AccReg)
+                  .addReg(Instr->getOperand(2).getReg(),
+                          getKillRegState(Instr->getOperand(2).isKill()))
+                  .addReg(Instr->getOperand(3).getReg(),
+                          getKillRegState(Instr->getOperand(3).isKill()));
+      } else {
+        // For the remaining cases, we need to use an output register of one
+        // of the newly inserted instructions as operand 1.
+        AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg()
+                     ? MRI.createVirtualRegister(
+                           MRI.getRegClass(Root.getOperand(0).getReg()))
+                     : Instr->getOperand(0).getReg();
+        assert(IndexedReg.index() - MaxWidth >= 0);
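+        // Chain this node onto the value produced MaxWidth iterations
+        // earlier in the walk, i.e. the previous node of the same branch.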
+        auto AccumulatorInput =
+            ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
+        MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()),
+                      AccReg)
+                  .addReg(AccumulatorInput, getKillRegState(true))
+                  .addReg(Instr->getOperand(2).getReg(),
+                          getKillRegState(Instr->getOperand(2).isKill()))
+                  .addReg(Instr->getOperand(3).getReg(),
+                          getKillRegState(Instr->getOperand(3).isKill()));
+      }
+
+      MIB->setFlags(Instr->getFlags());
+      InstrIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size()));
+      InsInstrs.push_back(MIB);
+      DelInstrs.push_back(Instr);
+    }
+
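+    // The last MaxWidth instructions inserted above define the roots of the
+    // tree's branches; collect them and sum them back into the original
+    // result register.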
+    SmallVector<Register, 8> RegistersToReduce;
+    for (int i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size(); ++i) {
+      auto Reg = InsInstrs[i]->getOperand(0).getReg();
+      RegistersToReduce.push_back(Reg);
+    }
+
+    while (RegistersToReduce.size() > 1)
+      reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI,
+                            InstrIdxForVirtReg, Root.getOperand(0).getReg());

+    // We don't want to break; setting flags and adding Root to DelInstrs is
+    // handled here.
+    return;
+  }
  case AArch64MachineCombinerPattern::MULADDv8i8_OP1:
    Opc = AArch64::MLAv8i8;
    RC = &AArch64::FPR64RegClass;