11
11
// ===----------------------------------------------------------------------===//
12
12
13
13
#include " llvm/CodeGen/TargetInstrInfo.h"
14
+ #include " llvm/ADT/SmallSet.h"
14
15
#include " llvm/ADT/StringExtras.h"
15
16
#include " llvm/BinaryFormat/Dwarf.h"
16
17
#include " llvm/CodeGen/MachineCombinerPattern.h"
@@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer(
42
43
" disable-sched-hazard" , cl::Hidden, cl::init(false ),
43
44
cl::desc(" Disable hazard detection during preRA scheduling" ));
44
45
46
// Hidden boolean option, enabled by default. When it is false,
// getAccumulatorReassociationPatterns returns without proposing a pattern.
+ static cl::opt<bool > EnableAccReassociation (
47
+ " acc-reassoc" , cl::Hidden, cl::init(true ),
48
+ cl::desc(" Enable reassociation of accumulation chains" ));
49
+
50
// Hidden option, default 8. Collected accumulator chains with fewer entries
// than this value are rejected by getAccumulatorReassociationPatterns.
+ static cl::opt<unsigned int >
51
+ MinAccumulatorDepth (" acc-min-depth" , cl::Hidden, cl::init(8 ),
52
+ cl::desc(" Minimum length of accumulator chains "
53
+ " required for the optimization to kick in" ));
54
+
55
// Hidden option, default 3. genAlternativeCodeSequence asserts that it is
// greater than 1 and uses it as the upper bound on the tree width (the width
// actually used is the smaller of this value and Log2_32 of the chain length).
+ static cl::opt<unsigned int > MaxAccumulatorWidth (
56
+ " acc-max-width" , cl::Hidden, cl::init(3 ),
57
+ cl::desc(" Maximum number of branches in the accumulator tree" ));
58
+
45
59
TargetInstrInfo::~TargetInstrInfo () = default ;
46
60
47
61
const TargetRegisterClass*
@@ -899,6 +913,158 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
899
913
hasReassociableSibling (Inst, Commuted);
900
914
}
901
915
916
+ // Utility routine that checks if \param MO is defined by an
917
+ // \param CombineOpc instruction in the basic block \param MBB.
918
+ // If \param CombineOpc is not provided, the OpCode check will
919
+ // be skipped.
920
+ static bool canCombine (MachineBasicBlock &MBB, MachineOperand &MO,
921
+ unsigned CombineOpc = 0 ) {
922
+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
923
+ MachineInstr *MI = nullptr ;
924
+
925
+ if (MO.isReg () && MO.getReg ().isVirtual ())
926
+ MI = MRI.getUniqueVRegDef (MO.getReg ());
927
+ // And it needs to be in the trace (otherwise, it won't have a depth).
928
+ if (!MI || MI->getParent () != &MBB ||
929
+ ((unsigned )MI->getOpcode () != CombineOpc && CombineOpc != 0 ))
930
+ return false ;
931
+ // Must only used by the user we combine with.
932
+ if (!MRI.hasOneNonDBGUse (MI->getOperand (0 ).getReg ()))
933
+ return false ;
934
+
935
+ return true ;
936
+ }
937
+
938
+ // A chain of accumulation instructions will be selected IFF:
939
+ // 1. All the accumulation instructions in the chain have the same opcode,
940
+ // besides the first that has a slightly different opcode because it does
941
+ // not perform the accumulation, just defines it.
942
+ // 2. All the instructions in the chain are combinable (have a single use
943
+ // which itself is part of the chain).
944
+ // 3. Meets the required minimum length.
945
+ void TargetInstrInfo::getAccumulatorChain (
946
+ MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
947
+ // Walk up the chain of accumulation instructions and collect them in the
948
+ // vector.
949
+ MachineBasicBlock &MBB = *CurrentInstr->getParent ();
950
+ const MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
951
+ unsigned AccumulatorOpcode = CurrentInstr->getOpcode ();
952
+ std::optional<unsigned > ChainStartOpCode =
953
+ getAccumulationStartOpcode (AccumulatorOpcode);
954
+
955
+ if (!ChainStartOpCode.has_value ())
956
+ return ;
957
+
958
+ // Push the first accumulator result to the start of the chain.
959
+ Chain.push_back (CurrentInstr->getOperand (0 ).getReg ());
960
+
961
+ // Collect the accumulator input register from all instructions in the chain.
962
+ while (CurrentInstr &&
963
+ canCombine (MBB, CurrentInstr->getOperand (1 ), AccumulatorOpcode)) {
964
+ Chain.push_back (CurrentInstr->getOperand (1 ).getReg ());
965
+ CurrentInstr = MRI.getUniqueVRegDef (CurrentInstr->getOperand (1 ).getReg ());
966
+ }
967
+
968
+ // Add the instruction at the top of the chain.
969
+ if (CurrentInstr->getOpcode () == AccumulatorOpcode &&
970
+ canCombine (MBB, CurrentInstr->getOperand (1 )))
971
+ Chain.push_back (CurrentInstr->getOperand (1 ).getReg ());
972
+ }
973
+
974
+ // / Find chains of accumulations that can be rewritten as a tree for increased
975
+ // / ILP.
976
+ bool TargetInstrInfo::getAccumulatorReassociationPatterns (
977
+ MachineInstr &Root, SmallVectorImpl<unsigned > &Patterns) const {
978
+ if (!EnableAccReassociation)
979
+ return false ;
980
+
981
+ unsigned Opc = Root.getOpcode ();
982
+ if (!isAccumulationOpcode (Opc))
983
+ return false ;
984
+
985
+ // Verify that this is the end of the chain.
986
+ MachineBasicBlock &MBB = *Root.getParent ();
987
+ MachineRegisterInfo &MRI = MBB.getParent ()->getRegInfo ();
988
+ if (!MRI.hasOneNonDBGUser (Root.getOperand (0 ).getReg ()))
989
+ return false ;
990
+
991
+ auto User = MRI.use_instr_begin (Root.getOperand (0 ).getReg ());
992
+ if (User->getOpcode () == Opc)
993
+ return false ;
994
+
995
+ // Walk up the use chain and collect the reduction chain.
996
+ SmallVector<Register, 32 > Chain;
997
+ getAccumulatorChain (&Root, Chain);
998
+
999
+ // Reject chains which are too short to be worth modifying.
1000
+ if (Chain.size () < MinAccumulatorDepth)
1001
+ return false ;
1002
+
1003
+ // Check if the MBB this instruction is a part of contains any other chains.
1004
+ // If so, don't apply it.
1005
+ SmallSet<Register, 32 > ReductionChain (Chain.begin (), Chain.end ());
1006
+ for (const auto &I : MBB) {
1007
+ if (I.getOpcode () == Opc &&
1008
+ !ReductionChain.contains (I.getOperand (0 ).getReg ()))
1009
+ return false ;
1010
+ }
1011
+
1012
+ Patterns.push_back (MachineCombinerPattern::ACC_CHAIN);
1013
+ return true ;
1014
+ }
1015
+
1016
+ // Reduce branches of the accumulator tree by adding them together.
1017
+ void TargetInstrInfo::reduceAccumulatorTree (
1018
+ SmallVectorImpl<Register> &RegistersToReduce,
1019
+ SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
1020
+ MachineInstr &Root, MachineRegisterInfo &MRI,
1021
+ DenseMap<unsigned , unsigned > &InstrIdxForVirtReg,
1022
+ Register ResultReg) const {
1023
+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
1024
+ SmallVector<Register, 8 > NewRegs;
1025
+
1026
+ // Get the opcode for the reduction instruction we will need to build.
1027
+ // If for some reason it is not defined, early exit and don't apply this.
1028
+ std::optional<unsigned > ReduceOpCode =
1029
+ getReduceOpcodeForAccumulator (Root.getOpcode ());
1030
+
1031
+ if (!ReduceOpCode.value ())
1032
+ return ;
1033
+
1034
+ for (unsigned int i = 1 ; i <= (RegistersToReduce.size () / 2 ); i += 2 ) {
1035
+ auto RHS = RegistersToReduce[i - 1 ];
1036
+ auto LHS = RegistersToReduce[i];
1037
+ Register Dest;
1038
+ // If we are reducing 2 registers, reuse the original result register.
1039
+ if (RegistersToReduce.size () == 2 )
1040
+ Dest = ResultReg;
1041
+ // Otherwise, create a new virtual register to hold the partial sum.
1042
+ else {
1043
+ auto NewVR = MRI.createVirtualRegister (
1044
+ MRI.getRegClass (Root.getOperand (0 ).getReg ()));
1045
+ Dest = NewVR;
1046
+ NewRegs.push_back (Dest);
1047
+ InstrIdxForVirtReg.insert (std::make_pair (Dest, InsInstrs.size ()));
1048
+ }
1049
+
1050
+ // Create the new reduction instruction.
1051
+ MachineInstrBuilder MIB =
1052
+ BuildMI (MF, MIMetadata (Root), TII->get (ReduceOpCode.value ()), Dest)
1053
+ .addReg (RHS, getKillRegState (true ))
1054
+ .addReg (LHS, getKillRegState (true ));
1055
+ // Copy any flags needed from the original instruction.
1056
+ MIB->setFlags (Root.getFlags ());
1057
+ InsInstrs.push_back (MIB);
1058
+ }
1059
+
1060
+ // If the number of registers to reduce is odd, add the reminaing register to
1061
+ // the vector of registers to reduce.
1062
+ if (RegistersToReduce.size () % 2 != 0 )
1063
+ NewRegs.push_back (RegistersToReduce[RegistersToReduce.size () - 1 ]);
1064
+
1065
+ RegistersToReduce = NewRegs;
1066
+ }
1067
+
902
1068
// The concept of the reassociation pass is that these operations can benefit
903
1069
// from this kind of transformation:
904
1070
//
@@ -938,6 +1104,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns(
938
1104
}
939
1105
return true ;
940
1106
}
1107
+ if (getAccumulatorReassociationPatterns (Root, Patterns))
1108
+ return true ;
941
1109
942
1110
return false ;
943
1111
}
@@ -949,7 +1117,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const {
949
1117
950
1118
// Maps a combiner pattern to the objective used when evaluating it:
// ACC_CHAIN yields MustReduceDepth, every other pattern yields Default.
CombinerObjective
TargetInstrInfo::getCombinerObjective(unsigned Pattern) const {
  if (Pattern == MachineCombinerPattern::ACC_CHAIN)
    return CombinerObjective::MustReduceDepth;
  return CombinerObjective::Default;
}
954
1127
955
1128
std::pair<unsigned , unsigned >
@@ -1252,19 +1425,101 @@ void TargetInstrInfo::genAlternativeCodeSequence(
1252
1425
SmallVectorImpl<MachineInstr *> &DelInstrs,
1253
1426
DenseMap<unsigned , unsigned > &InstIdxForVirtReg) const {
1254
1427
MachineRegisterInfo &MRI = Root.getMF ()->getRegInfo ();
1428
+ MachineBasicBlock &MBB = *Root.getParent ();
1429
+ MachineFunction &MF = *MBB.getParent ();
1430
+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
1255
1431
1256
- // Select the previous instruction in the sequence based on the input pattern.
1257
- std::array<unsigned , 5 > OperandIndices;
1258
- getReassociateOperandIndices (Root, Pattern, OperandIndices);
1259
- MachineInstr *Prev =
1260
- MRI.getUniqueVRegDef (Root.getOperand (OperandIndices[0 ]).getReg ());
1432
+ switch (Pattern) {
1433
+ case MachineCombinerPattern::REASSOC_AX_BY:
1434
+ case MachineCombinerPattern::REASSOC_AX_YB:
1435
+ case MachineCombinerPattern::REASSOC_XA_BY:
1436
+ case MachineCombinerPattern::REASSOC_XA_YB: {
1437
+ // Select the previous instruction in the sequence based on the input
1438
+ // pattern.
1439
+ std::array<unsigned , 5 > OperandIndices;
1440
+ getReassociateOperandIndices (Root, Pattern, OperandIndices);
1441
+ MachineInstr *Prev =
1442
+ MRI.getUniqueVRegDef (Root.getOperand (OperandIndices[0 ]).getReg ());
1443
+
1444
+ // Don't reassociate if Prev and Root are in different blocks.
1445
+ if (Prev->getParent () != Root.getParent ())
1446
+ return ;
1261
1447
1262
- // Don't reassociate if Prev and Root are in different blocks.
1263
- if (Prev->getParent () != Root.getParent ())
1264
- return ;
1448
+ reassociateOps (Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1449
+ InstIdxForVirtReg);
1450
+ break ;
1451
+ }
1452
+ case MachineCombinerPattern::ACC_CHAIN: {
1453
+ SmallVector<Register, 32 > ChainRegs;
1454
+ getAccumulatorChain (&Root, ChainRegs);
1455
+ unsigned int Depth = ChainRegs.size ();
1456
+ assert (MaxAccumulatorWidth > 1 &&
1457
+ " Max accumulator width set to illegal value" );
1458
+ unsigned int MaxWidth = Log2_32 (Depth) < MaxAccumulatorWidth
1459
+ ? Log2_32 (Depth)
1460
+ : MaxAccumulatorWidth;
1461
+
1462
+ // Walk down the chain and rewrite it as a tree.
1463
+ for (auto IndexedReg : llvm::enumerate (llvm::reverse (ChainRegs))) {
1464
+ // No need to rewrite the first node, it is already perfect as it is.
1465
+ if (IndexedReg.index () == 0 )
1466
+ continue ;
1467
+
1468
+ MachineInstr *Instr = MRI.getUniqueVRegDef (IndexedReg.value ());
1469
+ MachineInstrBuilder MIB;
1470
+ Register AccReg;
1471
+ if (IndexedReg.index () < MaxWidth) {
1472
+ // Now we need to create new instructions for the first row.
1473
+ AccReg = Instr->getOperand (0 ).getReg ();
1474
+ std::optional<unsigned > OpCode =
1475
+ getAccumulationStartOpcode (Root.getOpcode ());
1476
+ assert (OpCode.value () &&
1477
+ " Missing opcode for accumulation instruction." );
1478
+
1479
+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (OpCode.value ()), AccReg)
1480
+ .addReg (Instr->getOperand (2 ).getReg (),
1481
+ getKillRegState (Instr->getOperand (2 ).isKill ()))
1482
+ .addReg (Instr->getOperand (3 ).getReg (),
1483
+ getKillRegState (Instr->getOperand (3 ).isKill ()));
1484
+ } else {
1485
+ // For the remaining cases, we need to use an output register of one of
1486
+ // the newly inserted instructions as operand 1.
1487
+ AccReg = Instr->getOperand (0 ).getReg () == Root.getOperand (0 ).getReg ()
1488
+ ? MRI.createVirtualRegister (
1489
+ MRI.getRegClass (Root.getOperand (0 ).getReg ()))
1490
+ : Instr->getOperand (0 ).getReg ();
1491
+ assert (IndexedReg.index () - MaxWidth >= 0 );
1492
+ auto AccumulatorInput =
1493
+ ChainRegs[Depth - (IndexedReg.index () - MaxWidth) - 1 ];
1494
+ MIB = BuildMI (MF, MIMetadata (*Instr), TII->get (Instr->getOpcode ()),
1495
+ AccReg)
1496
+ .addReg (AccumulatorInput, getKillRegState (true ))
1497
+ .addReg (Instr->getOperand (2 ).getReg (),
1498
+ getKillRegState (Instr->getOperand (2 ).isKill ()))
1499
+ .addReg (Instr->getOperand (3 ).getReg (),
1500
+ getKillRegState (Instr->getOperand (3 ).isKill ()));
1501
+ }
1265
1502
1266
- reassociateOps (Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1267
- InstIdxForVirtReg);
1503
+ MIB->setFlags (Instr->getFlags ());
1504
+ InstIdxForVirtReg.insert (std::make_pair (AccReg, InsInstrs.size ()));
1505
+ InsInstrs.push_back (MIB);
1506
+ DelInstrs.push_back (Instr);
1507
+ }
1508
+
1509
+ SmallVector<Register, 8 > RegistersToReduce;
1510
+ for (unsigned i = (InsInstrs.size () - MaxWidth); i < InsInstrs.size ();
1511
+ ++i) {
1512
+ auto Reg = InsInstrs[i]->getOperand (0 ).getReg ();
1513
+ RegistersToReduce.push_back (Reg);
1514
+ }
1515
+
1516
+ while (RegistersToReduce.size () > 1 )
1517
+ reduceAccumulatorTree (RegistersToReduce, InsInstrs, MF, Root, MRI,
1518
+ InstIdxForVirtReg, Root.getOperand (0 ).getReg ());
1519
+
1520
+ break ;
1521
+ }
1522
+ }
1268
1523
}
1269
1524
1270
1525
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy () const {
0 commit comments