Skip to content

Commit ed03b9d

Browse files
authored
[MachineCombiner] Add a pass to reassociate chains of accumulation instructions into a tree (llvm#132728)
This pass is designed to increase ILP by performing accumulation into multiple registers. It currently supports only the S/UABAL accumulation instruction, but will be extended to support additional instructions. This is pattern appears often in code written using intrinsics, which get linearized in IR form be the Reassociate Pass, but unlike basic instructions such as add/mult/etc, do not have corresponding MachineCombiner patterns aimed at restoring the tree which is linearized in the earlier pass. rdar://78517468
2 parents 3f3fde0 + 969f9fa commit ed03b9d

File tree

8 files changed

+1631
-14
lines changed

8 files changed

+1631
-14
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum MachineCombinerPattern : unsigned {
3232
REASSOC_AX_YB,
3333
REASSOC_XA_BY,
3434
REASSOC_XA_YB,
35+
ACC_CHAIN,
3536

3637
TARGET_PATTERN_START
3738
};

llvm/include/llvm/CodeGen/TargetInstrInfo.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,41 @@ class TargetInstrInfo : public MCInstrInfo {
12561256
return false;
12571257
}
12581258

1259+
/// Find chains of accumulations that can be rewritten as a tree for increased
1260+
/// ILP.
1261+
bool getAccumulatorReassociationPatterns(
1262+
MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const;
1263+
1264+
/// Find the chain of accumulator instructions in \P MBB and return them in
1265+
/// \P Chain.
1266+
void getAccumulatorChain(MachineInstr *CurrentInstr,
1267+
SmallVectorImpl<Register> &Chain) const;
1268+
1269+
/// Return true when \P OpCode is an instruction which performs
1270+
/// accumulation into one of its operand registers.
1271+
virtual bool isAccumulationOpcode(unsigned Opcode) const { return false; }
1272+
1273+
/// Returns an opcode which defines the accumulator used by \P Opcode.
1274+
virtual unsigned getAccumulationStartOpcode(unsigned Opcode) const {
1275+
llvm_unreachable("Function not implemented for target!");
1276+
return 0;
1277+
}
1278+
1279+
/// Returns the opcode that should be use to reduce accumulation registers.
1280+
virtual unsigned
1281+
getReduceOpcodeForAccumulator(unsigned int AccumulatorOpCode) const {
1282+
llvm_unreachable("Function not implemented for target!");
1283+
return 0;
1284+
}
1285+
1286+
/// Reduces branches of the accumulator tree into a single register.
1287+
void reduceAccumulatorTree(SmallVectorImpl<Register> &RegistersToReduce,
1288+
SmallVectorImpl<MachineInstr *> &InsInstrs,
1289+
MachineFunction &MF, MachineInstr &Root,
1290+
MachineRegisterInfo &MRI,
1291+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
1292+
Register ResultReg) const;
1293+
12591294
/// Return the inverse operation opcode if it exists for \P Opcode (e.g. add
12601295
/// for sub and vice versa).
12611296
virtual std::optional<unsigned> getInverseOpcode(unsigned Opcode) const {

llvm/lib/CodeGen/TargetInstrInfo.cpp

Lines changed: 259 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "llvm/CodeGen/TargetInstrInfo.h"
14+
#include "llvm/ADT/SmallSet.h"
1415
#include "llvm/ADT/StringExtras.h"
1516
#include "llvm/BinaryFormat/Dwarf.h"
1617
#include "llvm/CodeGen/MachineCombinerPattern.h"
@@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer(
4243
"disable-sched-hazard", cl::Hidden, cl::init(false),
4344
cl::desc("Disable hazard detection during preRA scheduling"));
4445

46+
static cl::opt<bool> EnableAccReassociation(
47+
"acc-reassoc", cl::Hidden, cl::init(true),
48+
cl::desc("Enable reassociation of accumulation chains"));
49+
50+
static cl::opt<unsigned int>
51+
MinAccumulatorDepth("acc-min-depth", cl::Hidden, cl::init(8),
52+
cl::desc("Minimum length of accumulator chains "
53+
"required for the optimization to kick in"));
54+
55+
static cl::opt<unsigned int> MaxAccumulatorWidth(
56+
"acc-max-width", cl::Hidden, cl::init(3),
57+
cl::desc("Maximum number of branches in the accumulator tree"));
58+
4559
TargetInstrInfo::~TargetInstrInfo() = default;
4660

4761
const TargetRegisterClass*
@@ -897,6 +911,154 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
897911
hasReassociableSibling(Inst, Commuted);
898912
}
899913

914+
// Utility routine that checks if \param MO is defined by an
915+
// \param CombineOpc instruction in the basic block \param MBB.
916+
// If \param CombineOpc is not provided, the OpCode check will
917+
// be skipped.
918+
static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO,
919+
unsigned CombineOpc = 0) {
920+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
921+
MachineInstr *MI = nullptr;
922+
923+
if (MO.isReg() && MO.getReg().isVirtual())
924+
MI = MRI.getUniqueVRegDef(MO.getReg());
925+
// And it needs to be in the trace (otherwise, it won't have a depth).
926+
if (!MI || MI->getParent() != &MBB ||
927+
((unsigned)MI->getOpcode() != CombineOpc && CombineOpc != 0))
928+
return false;
929+
// Must only used by the user we combine with.
930+
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
931+
return false;
932+
933+
return true;
934+
}
935+
936+
// A chain of accumulation instructions will be selected IFF:
937+
// 1. All the accumulation instructions in the chain have the same opcode,
938+
// besides the first that has a slightly different opcode because it does
939+
// not accumulate into a register.
940+
// 2. All the instructions in the chain are combinable (have a single use
941+
// which itself is part of the chain).
942+
// 3. Meets the required minimum length.
943+
void TargetInstrInfo::getAccumulatorChain(
944+
MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
945+
// Walk up the chain of accumulation instructions and collect them in the
946+
// vector.
947+
MachineBasicBlock &MBB = *CurrentInstr->getParent();
948+
const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
949+
unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
950+
std::optional<unsigned> ChainStartOpCode =
951+
getAccumulationStartOpcode(AccumulatorOpcode);
952+
953+
if (!ChainStartOpCode.has_value())
954+
return;
955+
956+
// Push the first accumulator result to the start of the chain.
957+
Chain.push_back(CurrentInstr->getOperand(0).getReg());
958+
959+
// Collect the accumulator input register from all instructions in the chain.
960+
while (CurrentInstr &&
961+
canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode)) {
962+
Chain.push_back(CurrentInstr->getOperand(1).getReg());
963+
CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg());
964+
}
965+
966+
// Add the instruction at the top of the chain.
967+
if (CurrentInstr->getOpcode() == AccumulatorOpcode &&
968+
canCombine(MBB, CurrentInstr->getOperand(1)))
969+
Chain.push_back(CurrentInstr->getOperand(1).getReg());
970+
}
971+
972+
/// Find chains of accumulations that can be rewritten as a tree for increased
973+
/// ILP.
974+
bool TargetInstrInfo::getAccumulatorReassociationPatterns(
975+
MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const {
976+
if (!EnableAccReassociation)
977+
return false;
978+
979+
unsigned Opc = Root.getOpcode();
980+
if (!isAccumulationOpcode(Opc))
981+
return false;
982+
983+
// Verify that this is the end of the chain.
984+
MachineBasicBlock &MBB = *Root.getParent();
985+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
986+
if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
987+
return false;
988+
989+
auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
990+
if (User->getOpcode() == Opc)
991+
return false;
992+
993+
// Walk up the use chain and collect the reduction chain.
994+
SmallVector<Register, 32> Chain;
995+
getAccumulatorChain(&Root, Chain);
996+
997+
// Reject chains which are too short to be worth modifying.
998+
if (Chain.size() < MinAccumulatorDepth)
999+
return false;
1000+
1001+
// Check if the MBB this instruction is a part of contains any other chains.
1002+
// If so, don't apply it.
1003+
SmallSetVector<Register, 32> ReductionChain(Chain.begin(), Chain.end());
1004+
for (const auto &I : MBB) {
1005+
if (I.getOpcode() == Opc &&
1006+
!ReductionChain.contains(I.getOperand(0).getReg()))
1007+
return false;
1008+
}
1009+
1010+
Patterns.push_back(MachineCombinerPattern::ACC_CHAIN);
1011+
return true;
1012+
}
1013+
1014+
// Reduce branches of the accumulator tree by adding them together.
1015+
void TargetInstrInfo::reduceAccumulatorTree(
1016+
SmallVectorImpl<Register> &RegistersToReduce,
1017+
SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
1018+
MachineInstr &Root, MachineRegisterInfo &MRI,
1019+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
1020+
Register ResultReg) const {
1021+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
1022+
SmallVector<Register, 8> NewRegs;
1023+
1024+
// Get the opcode for the reduction instruction we will need to build.
1025+
// If for some reason it is not defined, early exit and don't apply this.
1026+
unsigned ReduceOpCode = getReduceOpcodeForAccumulator(Root.getOpcode());
1027+
1028+
for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) {
1029+
auto RHS = RegistersToReduce[i - 1];
1030+
auto LHS = RegistersToReduce[i];
1031+
Register Dest;
1032+
// If we are reducing 2 registers, reuse the original result register.
1033+
if (RegistersToReduce.size() == 2)
1034+
Dest = ResultReg;
1035+
// Otherwise, create a new virtual register to hold the partial sum.
1036+
else {
1037+
auto NewVR = MRI.createVirtualRegister(
1038+
MRI.getRegClass(Root.getOperand(0).getReg()));
1039+
Dest = NewVR;
1040+
NewRegs.push_back(Dest);
1041+
InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
1042+
}
1043+
1044+
// Create the new reduction instruction.
1045+
MachineInstrBuilder MIB =
1046+
BuildMI(MF, MIMetadata(Root), TII->get(ReduceOpCode), Dest)
1047+
.addReg(RHS, getKillRegState(true))
1048+
.addReg(LHS, getKillRegState(true));
1049+
// Copy any flags needed from the original instruction.
1050+
MIB->setFlags(Root.getFlags());
1051+
InsInstrs.push_back(MIB);
1052+
}
1053+
1054+
// If the number of registers to reduce is odd, add the remaining register to
1055+
// the vector of registers to reduce.
1056+
if (RegistersToReduce.size() % 2 != 0)
1057+
NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
1058+
1059+
RegistersToReduce = NewRegs;
1060+
}
1061+
9001062
// The concept of the reassociation pass is that these operations can benefit
9011063
// from this kind of transformation:
9021064
//
@@ -936,6 +1098,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns(
9361098
}
9371099
return true;
9381100
}
1101+
if (getAccumulatorReassociationPatterns(Root, Patterns))
1102+
return true;
9391103

9401104
return false;
9411105
}
@@ -947,7 +1111,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const {
9471111

9481112
CombinerObjective
9491113
TargetInstrInfo::getCombinerObjective(unsigned Pattern) const {
950-
return CombinerObjective::Default;
1114+
switch (Pattern) {
1115+
case MachineCombinerPattern::ACC_CHAIN:
1116+
return CombinerObjective::MustReduceDepth;
1117+
default:
1118+
return CombinerObjective::Default;
1119+
}
9511120
}
9521121

9531122
std::pair<unsigned, unsigned>
@@ -1250,19 +1419,98 @@ void TargetInstrInfo::genAlternativeCodeSequence(
12501419
SmallVectorImpl<MachineInstr *> &DelInstrs,
12511420
DenseMap<unsigned, unsigned> &InstIdxForVirtReg) const {
12521421
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
1422+
MachineBasicBlock &MBB = *Root.getParent();
1423+
MachineFunction &MF = *MBB.getParent();
1424+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
12531425

1254-
// Select the previous instruction in the sequence based on the input pattern.
1255-
std::array<unsigned, 5> OperandIndices;
1256-
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1257-
MachineInstr *Prev =
1258-
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
1426+
switch (Pattern) {
1427+
case MachineCombinerPattern::REASSOC_AX_BY:
1428+
case MachineCombinerPattern::REASSOC_AX_YB:
1429+
case MachineCombinerPattern::REASSOC_XA_BY:
1430+
case MachineCombinerPattern::REASSOC_XA_YB: {
1431+
// Select the previous instruction in the sequence based on the input
1432+
// pattern.
1433+
std::array<unsigned, 5> OperandIndices;
1434+
getReassociateOperandIndices(Root, Pattern, OperandIndices);
1435+
MachineInstr *Prev =
1436+
MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg());
1437+
1438+
// Don't reassociate if Prev and Root are in different blocks.
1439+
if (Prev->getParent() != Root.getParent())
1440+
return;
12591441

1260-
// Don't reassociate if Prev and Root are in different blocks.
1261-
if (Prev->getParent() != Root.getParent())
1262-
return;
1442+
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1443+
InstIdxForVirtReg);
1444+
break;
1445+
}
1446+
case MachineCombinerPattern::ACC_CHAIN: {
1447+
SmallVector<Register, 32> ChainRegs;
1448+
getAccumulatorChain(&Root, ChainRegs);
1449+
unsigned int Depth = ChainRegs.size();
1450+
assert(MaxAccumulatorWidth > 1 &&
1451+
"Max accumulator width set to illegal value");
1452+
unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth
1453+
? Log2_32(Depth)
1454+
: MaxAccumulatorWidth;
1455+
1456+
// Walk down the chain and rewrite it as a tree.
1457+
for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
1458+
// No need to rewrite the first node, it is already perfect as it is.
1459+
if (IndexedReg.index() == 0)
1460+
continue;
1461+
1462+
MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value());
1463+
MachineInstrBuilder MIB;
1464+
Register AccReg;
1465+
if (IndexedReg.index() < MaxWidth) {
1466+
// Now we need to create new instructions for the first row.
1467+
AccReg = Instr->getOperand(0).getReg();
1468+
unsigned OpCode = getAccumulationStartOpcode(Root.getOpcode());
1469+
1470+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(OpCode), AccReg)
1471+
.addReg(Instr->getOperand(2).getReg(),
1472+
getKillRegState(Instr->getOperand(2).isKill()))
1473+
.addReg(Instr->getOperand(3).getReg(),
1474+
getKillRegState(Instr->getOperand(3).isKill()));
1475+
} else {
1476+
// For the remaining cases, we need to use an output register of one of
1477+
// the newly inserted instuctions as operand 1
1478+
AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg()
1479+
? MRI.createVirtualRegister(
1480+
MRI.getRegClass(Root.getOperand(0).getReg()))
1481+
: Instr->getOperand(0).getReg();
1482+
assert(IndexedReg.index() >= MaxWidth);
1483+
auto AccumulatorInput =
1484+
ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
1485+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()),
1486+
AccReg)
1487+
.addReg(AccumulatorInput, getKillRegState(true))
1488+
.addReg(Instr->getOperand(2).getReg(),
1489+
getKillRegState(Instr->getOperand(2).isKill()))
1490+
.addReg(Instr->getOperand(3).getReg(),
1491+
getKillRegState(Instr->getOperand(3).isKill()));
1492+
}
12631493

1264-
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices,
1265-
InstIdxForVirtReg);
1494+
MIB->setFlags(Instr->getFlags());
1495+
InstIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size()));
1496+
InsInstrs.push_back(MIB);
1497+
DelInstrs.push_back(Instr);
1498+
}
1499+
1500+
SmallVector<Register, 8> RegistersToReduce;
1501+
for (unsigned i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size();
1502+
++i) {
1503+
auto Reg = InsInstrs[i]->getOperand(0).getReg();
1504+
RegistersToReduce.push_back(Reg);
1505+
}
1506+
1507+
while (RegistersToReduce.size() > 1)
1508+
reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI,
1509+
InstIdxForVirtReg, Root.getOperand(0).getReg());
1510+
1511+
break;
1512+
}
1513+
}
12661514
}
12671515

12681516
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {

0 commit comments

Comments
 (0)