Skip to content

Commit 6fd4cd5

Browse files
committed
[Machine-Combiner] Add a pass to reassociate chains of accumulation instructions into a tree
This pass is designed to increase ILP by performing accumulation into multiple registers. It currently supports only the S/UABAL accumulation instruction, but can be extended to support additional instructions.
1 parent 4398a22 commit 6fd4cd5

File tree

8 files changed

+1634
-14
lines changed

8 files changed

+1634
-14
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum MachineCombinerPattern : unsigned {
3232
REASSOC_AX_YB,
3333
REASSOC_XA_BY,
3434
REASSOC_XA_YB,
35+
ACC_CHAIN,
3536

3637
TARGET_PATTERN_START
3738
};

llvm/include/llvm/CodeGen/TargetInstrInfo.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,41 @@ class TargetInstrInfo : public MCInstrInfo {
12761276
return false;
12771277
}
12781278

1279+
/// Find chains of accumulations that can be rewritten as a tree for increased
1280+
/// ILP.
1281+
bool getAccumulatorReassociationPatterns(
1282+
MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const;
1283+
1284+
/// Find the chain of accumulator instructions in \P MBB and return them in
1285+
/// \P Chain.
1286+
void getAccumulatorChain(MachineInstr *CurrentInstr,
1287+
SmallVectorImpl<Register> &Chain) const;
1288+
1289+
/// Return true when \P Opcode is an instruction which performs
1290+
/// accumulation into one of its operand registers. This default
/// implementation returns false for every opcode.
1291+
virtual bool isAccumulationOpcode(unsigned Opcode) const { return false; }
1292+
1293+
/// Returns an opcode which defines the accumulator used by \P Opcode.
/// This default implementation is not meant to be called: it hits
/// llvm_unreachable. NOTE(review): presumably any target that returns true
/// from isAccumulationOpcode is expected to override this hook as well.
1294+
virtual unsigned getAccumulationStartOpcode(unsigned Opcode) const {
1295+
llvm_unreachable("Function not implemented for target!");
1296+
// Not reached; follows the llvm_unreachable above.
return 0;
1297+
}
1298+
1299+
/// Returns the opcode that should be used to reduce accumulation registers,
/// i.e. the opcode of the instruction that combines two partial accumulators
/// produced by \P AccumulatorOpCode into one. This default implementation is
/// not meant to be called: it hits llvm_unreachable.
1300+
virtual unsigned
1301+
getReduceOpcodeForAccumulator(unsigned int AccumulatorOpCode) const {
1302+
llvm_unreachable("Function not implemented for target!");
1303+
// Not reached; follows the llvm_unreachable above.
return 0;
1304+
}
1305+
1306+
/// Reduces branches of the accumulator tree into a single register.
1307+
void reduceAccumulatorTree(SmallVectorImpl<Register> &RegistersToReduce,
1308+
SmallVectorImpl<MachineInstr *> &InsInstrs,
1309+
MachineFunction &MF, MachineInstr &Root,
1310+
MachineRegisterInfo &MRI,
1311+
DenseMap<Register, unsigned> &InstrIdxForVirtReg,
1312+
Register ResultReg) const;
1313+
12791314
/// Return the inverse operation opcode if it exists for \P Opcode (e.g. add
12801315
/// for sub and vice versa).
12811316
virtual std::optional<unsigned> getInverseOpcode(unsigned Opcode) const {

llvm/lib/CodeGen/TargetInstrInfo.cpp

Lines changed: 262 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "llvm/CodeGen/TargetInstrInfo.h"
14+
#include "llvm/ADT/SmallSet.h"
1415
#include "llvm/ADT/StringExtras.h"
1516
#include "llvm/BinaryFormat/Dwarf.h"
1617
#include "llvm/CodeGen/MachineCombinerPattern.h"
@@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer(
4243
"disable-sched-hazard", cl::Hidden, cl::init(false),
4344
cl::desc("Disable hazard detection during preRA scheduling"));
4445

46+
static cl::opt<bool> EnableAccReassociation(
47+
"acc-reassoc", cl::Hidden, cl::init(true),
48+
cl::desc("Enable reassociation of accumulation chains"));
49+
50+
static cl::opt<unsigned int>
51+
MinAccumulatorDepth("acc-min-depth", cl::Hidden, cl::init(8),
52+
cl::desc("Minimum length of accumulator chains "
53+
"required for the optimization to kick in"));
54+
55+
static cl::opt<unsigned int> MaxAccumulatorWidth(
56+
"acc-max-width", cl::Hidden, cl::init(3),
57+
cl::desc("Maximum number of branches in the accumulator tree"));
58+
4559
TargetInstrInfo::~TargetInstrInfo() = default;
4660

4761
const TargetRegisterClass*
@@ -899,6 +913,154 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
899913
hasReassociableSibling(Inst, Commuted);
900914
}
901915

916+
// Utility routine that checks if \param MO is defined by an
917+
// \param CombineOpc instruction in the basic block \param MBB.
918+
// If \param CombineOpc is not provided, the OpCode check will
919+
// be skipped.
920+
static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO,
921+
unsigned CombineOpc = 0) {
922+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
923+
MachineInstr *MI = nullptr;
924+
925+
if (MO.isReg() && MO.getReg().isVirtual())
926+
MI = MRI.getUniqueVRegDef(MO.getReg());
927+
// And it needs to be in the trace (otherwise, it won't have a depth).
928+
if (!MI || MI->getParent() != &MBB ||
929+
((unsigned)MI->getOpcode() != CombineOpc && CombineOpc != 0))
930+
return false;
931+
// Must only used by the user we combine with.
932+
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
933+
return false;
934+
935+
return true;
936+
}
937+
938+
// A chain of accumulation instructions will be selected IFF:
939+
// 1. All the accumulation instructions in the chain have the same opcode,
940+
// besides the first that has a slightly different opcode because it does
941+
// not accumulate into a register.
942+
// 2. All the instructions in the chain are combinable (have a single use
943+
// which itself is part of the chain).
944+
// 3. Meets the required minimum length.
945+
void TargetInstrInfo::getAccumulatorChain(
946+
MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
947+
// Walk up the chain of accumulation instructions and collect them in the
948+
// vector.
949+
MachineBasicBlock &MBB = *CurrentInstr->getParent();
950+
const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
951+
unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
952+
std::optional<unsigned> ChainStartOpCode =
953+
getAccumulationStartOpcode(AccumulatorOpcode);
954+
955+
if (!ChainStartOpCode.has_value())
956+
return;
957+
958+
// Push the first accumulator result to the start of the chain.
959+
Chain.push_back(CurrentInstr->getOperand(0).getReg());
960+
961+
// Collect the accumulator input register from all instructions in the chain.
962+
while (CurrentInstr &&
963+
canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode)) {
964+
Chain.push_back(CurrentInstr->getOperand(1).getReg());
965+
CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg());
966+
}
967+
968+
// Add the instruction at the top of the chain.
969+
if (CurrentInstr->getOpcode() == AccumulatorOpcode &&
970+
canCombine(MBB, CurrentInstr->getOperand(1)))
971+
Chain.push_back(CurrentInstr->getOperand(1).getReg());
972+
}
973+
974+
/// Find chains of accumulations that can be rewritten as a tree for increased
975+
/// ILP.
976+
bool TargetInstrInfo::getAccumulatorReassociationPatterns(
977+
MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const {
978+
if (!EnableAccReassociation)
979+
return false;
980+
981+
unsigned Opc = Root.getOpcode();
982+
if (!isAccumulationOpcode(Opc))
983+
return false;
984+
985+
// Verify that this is the end of the chain.
986+
MachineBasicBlock &MBB = *Root.getParent();
987+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
988+
if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
989+
return false;
990+
991+
auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
992+
if (User->getOpcode() == Opc)
993+
return false;
994+
995+
// Walk up the use chain and collect the reduction chain.
996+
SmallVector<Register, 32> Chain;
997+
getAccumulatorChain(&Root, Chain);
998+
999+
// Reject chains which are too short to be worth modifying.
1000+
if (Chain.size() < MinAccumulatorDepth)
1001+
return false;
1002+
1003+
// Check if the MBB this instruction is a part of contains any other chains.
1004+
// If so, don't apply it.
1005+
SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end());
1006+
for (const auto &I : MBB) {
1007+
if (I.getOpcode() == Opc &&
1008+
!ReductionChain.contains(I.getOperand(0).getReg()))
1009+
return false;
1010+
}
1011+
1012+
Patterns.push_back(MachineCombinerPattern::ACC_CHAIN);
1013+
return true;
1014+
}
1015+
1016+
// Reduce branches of the accumulator tree by adding them together.
1017+
void TargetInstrInfo::reduceAccumulatorTree(
1018+
SmallVectorImpl<Register> &RegistersToReduce,
1019+
SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
1020+
MachineInstr &Root, MachineRegisterInfo &MRI,
1021+
DenseMap<Register, unsigned> &InstrIdxForVirtReg,
1022+
Register ResultReg) const {
1023+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
1024+
SmallVector<Register, 8> NewRegs;
1025+
1026+
// Get the opcode for the reduction instruction we will need to build.
1027+
// If for some reason it is not defined, early exit and don't apply this.
1028+
unsigned ReduceOpCode = getReduceOpcodeForAccumulator(Root.getOpcode());
1029+
1030+
for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) {
1031+
auto RHS = RegistersToReduce[i - 1];
1032+
auto LHS = RegistersToReduce[i];
1033+
Register Dest;
1034+
// If we are reducing 2 registers, reuse the original result register.
1035+
if (RegistersToReduce.size() == 2)
1036+
Dest = ResultReg;
1037+
// Otherwise, create a new virtual register to hold the partial sum.
1038+
else {
1039+
auto NewVR = MRI.createVirtualRegister(
1040+
MRI.getRegClass(Root.getOperand(0).getReg()));
1041+
Dest = NewVR;
1042+
NewRegs.push_back(Dest);
1043+
InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
1044+
}
1045+
1046+
// Create the new reduction instruction.
1047+
MachineInstrBuilder MIB =
1048+
BuildMI(MF, MIMetadata(Root), TII->get(ReduceOpCode), Dest)
1049+
.addReg(RHS, getKillRegState(true))
1050+
.addReg(LHS, getKillRegState(true));
1051+
// Copy any flags needed from the original instruction.
1052+
MIB->setFlags(Root.getFlags());
1053+
InsInstrs.push_back(MIB);
1054+
}
1055+
1056+
// If the number of registers to reduce is odd, add the remaining register to
1057+
// the vector of registers to reduce.
1058+
if (RegistersToReduce.size() % 2 != 0)
1059+
NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
1060+
1061+
RegistersToReduce = NewRegs;
1062+
}
1063+
9021064
// The concept of the reassociation pass is that these operations can benefit
9031065
// from this kind of transformation:
9041066
//
@@ -938,6 +1100,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns(
9381100
}
9391101
return true;
9401102
}
1103+
if (getAccumulatorReassociationPatterns(Root, Patterns))
1104+
return true;
9411105

9421106
return false;
9431107
}
@@ -949,7 +1113,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const {
9491113

9501114
CombinerObjective
9511115
TargetInstrInfo::getCombinerObjective(unsigned Pattern) const {
952-
return CombinerObjective::Default;
1116+
switch (Pattern) {
1117+
case MachineCombinerPattern::ACC_CHAIN:
1118+
return CombinerObjective::MustReduceDepth;
1119+
default:
1120+
return CombinerObjective::Default;
1121+
}
9531122
}
9541123

9551124
std::pair<unsigned, unsigned>
@@ -1252,19 +1421,101 @@ void TargetInstrInfo::genAlternativeCodeSequence(
12521421
SmallVectorImpl<MachineInstr *> &DelInstrs,
12531422
DenseMap<Register, unsigned> &InstIdxForVirtReg) const {
12541423
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
1424+
MachineBasicBlock &MBB = *Root.getParent();
1425+
MachineFunction &MF = *MBB.getParent();
1426+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
12551427

1256-
// Select the previous instruction in the sequence based on the input pattern.
1257-
std::array<unsigned, 5> OperandIndices;
1258-
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1259-
MachineInstr *Prev =
1260-
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
1428+
switch (Pattern) {
1429+
case MachineCombinerPattern::REASSOC_AX_BY:
1430+
case MachineCombinerPattern::REASSOC_AX_YB:
1431+
case MachineCombinerPattern::REASSOC_XA_BY:
1432+
case MachineCombinerPattern::REASSOC_XA_YB: {
1433+
// Select the previous instruction in the sequence based on the input
1434+
// pattern.
1435+
std::array<unsigned, 5> OperandIndices;
1436+
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1437+
MachineInstr *Prev =
1438+
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
1439+
1440+
// Don't reassociate if Prev and Root are in different blocks.
1441+
if (Prev->getParent() != Root.getParent())
1442+
return;
12611443

1262-
// Don't reassociate if Prev and Root are in different blocks.
1263-
if (Prev->getParent() != Root.getParent())
1264-
return;
1444+
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1445+
InstIdxForVirtReg);
1446+
break;
1447+
}
1448+
case MachineCombinerPattern::ACC_CHAIN: {
1449+
SmallVector<Register, 32> ChainRegs;
1450+
getAccumulatorChain(&Root, ChainRegs);
1451+
unsigned int Depth = ChainRegs.size();
1452+
assert(MaxAccumulatorWidth > 1 &&
1453+
"Max accumulator width set to illegal value");
1454+
unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth
1455+
? Log2_32(Depth)
1456+
: MaxAccumulatorWidth;
1457+
1458+
// Walk down the chain and rewrite it as a tree.
1459+
for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
1460+
// No need to rewrite the first node, it is already perfect as it is.
1461+
if (IndexedReg.index() == 0)
1462+
continue;
1463+
1464+
MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value());
1465+
MachineInstrBuilder MIB;
1466+
Register AccReg;
1467+
if (IndexedReg.index() < MaxWidth) {
1468+
// Now we need to create new instructions for the first row.
1469+
AccReg = Instr->getOperand(0).getReg();
1470+
std::optional<unsigned> OpCode =
1471+
getAccumulationStartOpcode(Root.getOpcode());
1472+
assert(OpCode.value() &&
1473+
"Missing opcode for accumulation instruction.");
1474+
1475+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(OpCode.value()), AccReg)
1476+
.addReg(Instr->getOperand(2).getReg(),
1477+
getKillRegState(Instr->getOperand(2).isKill()))
1478+
.addReg(Instr->getOperand(3).getReg(),
1479+
getKillRegState(Instr->getOperand(3).isKill()));
1480+
} else {
1481+
// For the remaining cases, we need to use an output register of one of
1482+
// the newly inserted instructions as operand 1
1483+
AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg()
1484+
? MRI.createVirtualRegister(
1485+
MRI.getRegClass(Root.getOperand(0).getReg()))
1486+
: Instr->getOperand(0).getReg();
1487+
assert(IndexedReg.index() - MaxWidth >= 0);
1488+
auto AccumulatorInput =
1489+
ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
1490+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()),
1491+
AccReg)
1492+
.addReg(AccumulatorInput, getKillRegState(true))
1493+
.addReg(Instr->getOperand(2).getReg(),
1494+
getKillRegState(Instr->getOperand(2).isKill()))
1495+
.addReg(Instr->getOperand(3).getReg(),
1496+
getKillRegState(Instr->getOperand(3).isKill()));
1497+
}
12651498

1266-
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1267-
InstIdxForVirtReg);
1499+
MIB->setFlags(Instr->getFlags());
1500+
InstIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size()));
1501+
InsInstrs.push_back(MIB);
1502+
DelInstrs.push_back(Instr);
1503+
}
1504+
1505+
SmallVector<Register, 8> RegistersToReduce;
1506+
for (unsigned i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size();
1507+
++i) {
1508+
auto Reg = InsInstrs[i]->getOperand(0).getReg();
1509+
RegistersToReduce.push_back(Reg);
1510+
}
1511+
1512+
while (RegistersToReduce.size() > 1)
1513+
reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI,
1514+
InstIdxForVirtReg, Root.getOperand(0).getReg());
1515+
1516+
break;
1517+
}
1518+
}
12681519
}
12691520

12701521
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {

0 commit comments

Comments
 (0)