Skip to content

Commit 969f9fa

Browse files
committed
[Machine-Combiner] Add a pass to reassociate chains of accumulation instructions into a tree (llvm#132728)
This pass is designed to increase ILP by performing accumulation into multiple registers. It currently supports only the S/UABAL accumulation instruction, but can be extended to support additional instructions. Reland of llvm#126060 which was reverted due to a conflict with llvm#131272.
1 parent 7194803 commit 969f9fa

File tree

8 files changed

+1631
-14
lines changed

8 files changed

+1631
-14
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum MachineCombinerPattern : unsigned {
3232
REASSOC_AX_YB,
3333
REASSOC_XA_BY,
3434
REASSOC_XA_YB,
35+
ACC_CHAIN,
3536

3637
TARGET_PATTERN_START
3738
};

llvm/include/llvm/CodeGen/TargetInstrInfo.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,41 @@ class TargetInstrInfo : public MCInstrInfo {
12561256
return false;
12571257
}
12581258

1259+
/// Find chains of accumulations that can be rewritten as a tree for increased
1260+
/// ILP.
1261+
bool getAccumulatorReassociationPatterns(
1262+
MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const;
1263+
1264+
/// Walk up the chain of accumulation instructions ending at \P CurrentInstr
1265+
/// and collect the accumulator registers of that chain in \P Chain.
1266+
void getAccumulatorChain(MachineInstr *CurrentInstr,
1267+
SmallVectorImpl<Register> &Chain) const;
1268+
1269+
/// Return true when \P Opcode is the opcode of an instruction which performs
1270+
/// accumulation into one of its operand registers.
1271+
virtual bool isAccumulationOpcode(unsigned Opcode) const { return false; }
1272+
1273+
/// Returns an opcode which defines the accumulator used by \P Opcode.
1274+
virtual unsigned getAccumulationStartOpcode(unsigned Opcode) const {
1275+
llvm_unreachable("Function not implemented for target!");
1276+
return 0;
1277+
}
1278+
1279+
/// Returns the opcode that should be used to reduce accumulation registers.
1280+
virtual unsigned
1281+
getReduceOpcodeForAccumulator(unsigned int AccumulatorOpCode) const {
1282+
llvm_unreachable("Function not implemented for target!");
1283+
return 0;
1284+
}
1285+
1286+
/// Reduces branches of the accumulator tree into a single register.
1287+
void reduceAccumulatorTree(SmallVectorImpl<Register> &RegistersToReduce,
1288+
SmallVectorImpl<MachineInstr *> &InsInstrs,
1289+
MachineFunction &MF, MachineInstr &Root,
1290+
MachineRegisterInfo &MRI,
1291+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
1292+
Register ResultReg) const;
1293+
12591294
/// Return the inverse operation opcode if it exists for \P Opcode (e.g. add
12601295
/// for sub and vice versa).
12611296
virtual std::optional<unsigned> getInverseOpcode(unsigned Opcode) const {

llvm/lib/CodeGen/TargetInstrInfo.cpp

Lines changed: 259 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "llvm/CodeGen/TargetInstrInfo.h"
14+
#include "llvm/ADT/SmallSet.h"
1415
#include "llvm/ADT/StringExtras.h"
1516
#include "llvm/BinaryFormat/Dwarf.h"
1617
#include "llvm/CodeGen/MachineCombinerPattern.h"
@@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer(
4243
"disable-sched-hazard", cl::Hidden, cl::init(false),
4344
cl::desc("Disable hazard detection during preRA scheduling"));
4445

46+
static cl::opt<bool> EnableAccReassociation(
47+
"acc-reassoc", cl::Hidden, cl::init(true),
48+
cl::desc("Enable reassociation of accumulation chains"));
49+
50+
static cl::opt<unsigned int>
51+
MinAccumulatorDepth("acc-min-depth", cl::Hidden, cl::init(8),
52+
cl::desc("Minimum length of accumulator chains "
53+
"required for the optimization to kick in"));
54+
55+
static cl::opt<unsigned int> MaxAccumulatorWidth(
56+
"acc-max-width", cl::Hidden, cl::init(3),
57+
cl::desc("Maximum number of branches in the accumulator tree"));
58+
4559
TargetInstrInfo::~TargetInstrInfo() = default;
4660

4761
const TargetRegisterClass*
@@ -897,6 +911,154 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
897911
hasReassociableSibling(Inst, Commuted);
898912
}
899913

914+
// Utility routine that checks if \param MO is defined by an
915+
// \param CombineOpc instruction in the basic block \param MBB.
916+
// If \param CombineOpc is not provided, the OpCode check will
917+
// be skipped.
918+
static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO,
919+
unsigned CombineOpc = 0) {
920+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
921+
MachineInstr *MI = nullptr;
922+
923+
if (MO.isReg() && MO.getReg().isVirtual())
924+
MI = MRI.getUniqueVRegDef(MO.getReg());
925+
// And it needs to be in the trace (otherwise, it won't have a depth).
926+
if (!MI || MI->getParent() != &MBB ||
927+
((unsigned)MI->getOpcode() != CombineOpc && CombineOpc != 0))
928+
return false;
929+
// Must only used by the user we combine with.
930+
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
931+
return false;
932+
933+
return true;
934+
}
935+
936+
// A chain of accumulation instructions will be selected IFF:
937+
// 1. All the accumulation instructions in the chain have the same opcode,
938+
// besides the first that has a slightly different opcode because it does
939+
// not accumulate into a register.
940+
// 2. All the instructions in the chain are combinable (have a single use
941+
// which itself is part of the chain).
942+
// 3. Meets the required minimum length.
943+
void TargetInstrInfo::getAccumulatorChain(
944+
MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
945+
// Walk up the chain of accumulation instructions and collect them in the
946+
// vector.
947+
MachineBasicBlock &MBB = *CurrentInstr->getParent();
948+
const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
949+
unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
950+
std::optional<unsigned> ChainStartOpCode =
951+
getAccumulationStartOpcode(AccumulatorOpcode);
952+
953+
if (!ChainStartOpCode.has_value())
954+
return;
955+
956+
// Push the first accumulator result to the start of the chain.
957+
Chain.push_back(CurrentInstr->getOperand(0).getReg());
958+
959+
// Collect the accumulator input register from all instructions in the chain.
960+
while (CurrentInstr &&
961+
canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode)) {
962+
Chain.push_back(CurrentInstr->getOperand(1).getReg());
963+
CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg());
964+
}
965+
966+
// Add the instruction at the top of the chain.
967+
if (CurrentInstr->getOpcode() == AccumulatorOpcode &&
968+
canCombine(MBB, CurrentInstr->getOperand(1)))
969+
Chain.push_back(CurrentInstr->getOperand(1).getReg());
970+
}
971+
972+
/// Find chains of accumulations that can be rewritten as a tree for increased
973+
/// ILP.
974+
bool TargetInstrInfo::getAccumulatorReassociationPatterns(
975+
MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const {
976+
if (!EnableAccReassociation)
977+
return false;
978+
979+
unsigned Opc = Root.getOpcode();
980+
if (!isAccumulationOpcode(Opc))
981+
return false;
982+
983+
// Verify that this is the end of the chain.
984+
MachineBasicBlock &MBB = *Root.getParent();
985+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
986+
if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
987+
return false;
988+
989+
auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
990+
if (User->getOpcode() == Opc)
991+
return false;
992+
993+
// Walk up the use chain and collect the reduction chain.
994+
SmallVector<Register, 32> Chain;
995+
getAccumulatorChain(&Root, Chain);
996+
997+
// Reject chains which are too short to be worth modifying.
998+
if (Chain.size() < MinAccumulatorDepth)
999+
return false;
1000+
1001+
// Check if the MBB this instruction is a part of contains any other chains.
1002+
// If so, don't apply it.
1003+
SmallSetVector<Register, 32> ReductionChain(Chain.begin(), Chain.end());
1004+
for (const auto &I : MBB) {
1005+
if (I.getOpcode() == Opc &&
1006+
!ReductionChain.contains(I.getOperand(0).getReg()))
1007+
return false;
1008+
}
1009+
1010+
Patterns.push_back(MachineCombinerPattern::ACC_CHAIN);
1011+
return true;
1012+
}
1013+
1014+
// Reduce branches of the accumulator tree by adding them together.
1015+
void TargetInstrInfo::reduceAccumulatorTree(
1016+
SmallVectorImpl<Register> &RegistersToReduce,
1017+
SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
1018+
MachineInstr &Root, MachineRegisterInfo &MRI,
1019+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
1020+
Register ResultReg) const {
1021+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
1022+
SmallVector<Register, 8> NewRegs;
1023+
1024+
// Get the opcode for the reduction instruction we will need to build.
1025+
// If for some reason it is not defined, early exit and don't apply this.
1026+
unsigned ReduceOpCode = getReduceOpcodeForAccumulator(Root.getOpcode());
1027+
1028+
for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) {
1029+
auto RHS = RegistersToReduce[i - 1];
1030+
auto LHS = RegistersToReduce[i];
1031+
Register Dest;
1032+
// If we are reducing 2 registers, reuse the original result register.
1033+
if (RegistersToReduce.size() == 2)
1034+
Dest = ResultReg;
1035+
// Otherwise, create a new virtual register to hold the partial sum.
1036+
else {
1037+
auto NewVR = MRI.createVirtualRegister(
1038+
MRI.getRegClass(Root.getOperand(0).getReg()));
1039+
Dest = NewVR;
1040+
NewRegs.push_back(Dest);
1041+
InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
1042+
}
1043+
1044+
// Create the new reduction instruction.
1045+
MachineInstrBuilder MIB =
1046+
BuildMI(MF, MIMetadata(Root), TII->get(ReduceOpCode), Dest)
1047+
.addReg(RHS, getKillRegState(true))
1048+
.addReg(LHS, getKillRegState(true));
1049+
// Copy any flags needed from the original instruction.
1050+
MIB->setFlags(Root.getFlags());
1051+
InsInstrs.push_back(MIB);
1052+
}
1053+
1054+
// If the number of registers to reduce is odd, add the remaining register to
1055+
// the vector of registers to reduce.
1056+
if (RegistersToReduce.size() % 2 != 0)
1057+
NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
1058+
1059+
RegistersToReduce = NewRegs;
1060+
}
1061+
9001062
// The concept of the reassociation pass is that these operations can benefit
9011063
// from this kind of transformation:
9021064
//
@@ -936,6 +1098,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns(
9361098
}
9371099
return true;
9381100
}
1101+
if (getAccumulatorReassociationPatterns(Root, Patterns))
1102+
return true;
9391103

9401104
return false;
9411105
}
@@ -947,7 +1111,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const {
9471111

9481112
CombinerObjective
9491113
TargetInstrInfo::getCombinerObjective(unsigned Pattern) const {
950-
return CombinerObjective::Default;
1114+
switch (Pattern) {
1115+
case MachineCombinerPattern::ACC_CHAIN:
1116+
return CombinerObjective::MustReduceDepth;
1117+
default:
1118+
return CombinerObjective::Default;
1119+
}
9511120
}
9521121

9531122
std::pair<unsigned, unsigned>
@@ -1250,19 +1419,98 @@ void TargetInstrInfo::genAlternativeCodeSequence(
12501419
SmallVectorImpl<MachineInstr *> &DelInstrs,
12511420
DenseMap<unsigned, unsigned> &InstIdxForVirtReg) const {
12521421
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
1422+
MachineBasicBlock &MBB = *Root.getParent();
1423+
MachineFunction &MF = *MBB.getParent();
1424+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
12531425

1254-
// Select the previous instruction in the sequence based on the input pattern.
1255-
std::array<unsigned, 5> OperandIndices;
1256-
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1257-
MachineInstr *Prev =
1258-
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
1426+
switch (Pattern) {
1427+
case MachineCombinerPattern::REASSOC_AX_BY:
1428+
case MachineCombinerPattern::REASSOC_AX_YB:
1429+
case MachineCombinerPattern::REASSOC_XA_BY:
1430+
case MachineCombinerPattern::REASSOC_XA_YB: {
1431+
// Select the previous instruction in the sequence based on the input
1432+
// pattern.
1433+
std::array<unsigned, 5> OperandIndices;
1434+
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1435+
MachineInstr *Prev =
1436+
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
1437+
1438+
// Don't reassociate if Prev and Root are in different blocks.
1439+
if (Prev->getParent() != Root.getParent())
1440+
return;
12591441

1260-
// Don't reassociate if Prev and Root are in different blocks.
1261-
if (Prev->getParent() != Root.getParent())
1262-
return;
1442+
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1443+
InstIdxForVirtReg);
1444+
break;
1445+
}
1446+
case MachineCombinerPattern::ACC_CHAIN: {
1447+
SmallVector<Register, 32> ChainRegs;
1448+
getAccumulatorChain(&Root, ChainRegs);
1449+
unsigned int Depth = ChainRegs.size();
1450+
assert(MaxAccumulatorWidth > 1 &&
1451+
"Max accumulator width set to illegal value");
1452+
unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth
1453+
? Log2_32(Depth)
1454+
: MaxAccumulatorWidth;
1455+
1456+
// Walk down the chain and rewrite it as a tree.
1457+
for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
1458+
// No need to rewrite the first node, it is already perfect as it is.
1459+
if (IndexedReg.index() == 0)
1460+
continue;
1461+
1462+
MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value());
1463+
MachineInstrBuilder MIB;
1464+
Register AccReg;
1465+
if (IndexedReg.index() < MaxWidth) {
1466+
// Now we need to create new instructions for the first row.
1467+
AccReg = Instr->getOperand(0).getReg();
1468+
unsigned OpCode = getAccumulationStartOpcode(Root.getOpcode());
1469+
1470+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(OpCode), AccReg)
1471+
.addReg(Instr->getOperand(2).getReg(),
1472+
getKillRegState(Instr->getOperand(2).isKill()))
1473+
.addReg(Instr->getOperand(3).getReg(),
1474+
getKillRegState(Instr->getOperand(3).isKill()));
1475+
} else {
1476+
// For the remaining cases, we need to use an output register of one of
1477+
// the newly inserted instructions as operand 1.
1478+
AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg()
1479+
? MRI.createVirtualRegister(
1480+
MRI.getRegClass(Root.getOperand(0).getReg()))
1481+
: Instr->getOperand(0).getReg();
1482+
assert(IndexedReg.index() >= MaxWidth);
1483+
auto AccumulatorInput =
1484+
ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
1485+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()),
1486+
AccReg)
1487+
.addReg(AccumulatorInput, getKillRegState(true))
1488+
.addReg(Instr->getOperand(2).getReg(),
1489+
getKillRegState(Instr->getOperand(2).isKill()))
1490+
.addReg(Instr->getOperand(3).getReg(),
1491+
getKillRegState(Instr->getOperand(3).isKill()));
1492+
}
12631493

1264-
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1265-
InstIdxForVirtReg);
1494+
MIB->setFlags(Instr->getFlags());
1495+
InstIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size()));
1496+
InsInstrs.push_back(MIB);
1497+
DelInstrs.push_back(Instr);
1498+
}
1499+
1500+
SmallVector<Register, 8> RegistersToReduce;
1501+
for (unsigned i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size();
1502+
++i) {
1503+
auto Reg = InsInstrs[i]->getOperand(0).getReg();
1504+
RegistersToReduce.push_back(Reg);
1505+
}
1506+
1507+
while (RegistersToReduce.size() > 1)
1508+
reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI,
1509+
InstIdxForVirtReg, Root.getOperand(0).getReg());
1510+
1511+
break;
1512+
}
1513+
}
12661514
}
12671515

12681516
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {

0 commit comments

Comments
 (0)