
Commit 23eb60b

[MachineCombiner] Add a pass to reassociate chains of accumulation instructions into a tree
This pass is designed to increase ILP by performing accumulation into multiple registers. It currently supports only the UABAL accumulation instruction, but can easily be extended to support additional instructions.
1 parent 90d33ee commit 23eb60b
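
The following is an illustrative sketch (not part of this commit) of the transformation the new combiner pattern performs, shown as a scalar C++ analogy: a single accumulator forms one long dependency chain, while several independent partial accumulators that are added together at the end shorten the critical path and expose ILP. The chain length of 8 and the three partial accumulators mirror the defaults of the new aarch64-acc-min-depth and aarch64-acc-max-width options; the function names are made up for the example.

    // Illustrative only; this is not code from the patch.
    #include <cstdint>

    // Serial form: a depth-8 dependency chain through a single accumulator,
    // analogous to a linear chain of UABAL instructions.
    uint64_t accumulate_chain(const uint64_t d[8]) {
      uint64_t acc = d[0];
      for (int i = 1; i < 8; ++i)
        acc += d[i];
      return acc;
    }

    // Tree form: three independent partial accumulators (one per "branch" of
    // the accumulator tree), reduced with adds at the end, which is the shape
    // the pattern rewrites the chain into.
    uint64_t accumulate_tree(const uint64_t d[8]) {
      uint64_t a0 = d[0] + d[3] + d[6];
      uint64_t a1 = d[1] + d[4] + d[7];
      uint64_t a2 = d[2] + d[5];
      return (a0 + a1) + a2;
    }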

3 files changed: 330 additions and 18 deletions

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 282 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -78,6 +79,19 @@ static cl::opt<unsigned>
     BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
                       cl::desc("Restrict range of B instructions (DEBUG)"));

+static cl::opt<bool> EnableAccReassociation(
+    "aarch64-acc-reassoc", cl::Hidden, cl::init(true),
+    cl::desc("Enable reassociation of accumulation chains"));
+
+static cl::opt<unsigned int>
+    MinAccumulatorDepth("aarch64-acc-min-depth", cl::Hidden, cl::init(8),
+                        cl::desc("Minimum length of accumulator chains "
+                                 "required for the optimization to kick in"));
+
+static cl::opt<unsigned int> MaxAccumulatorWidth(
+    "aarch64-acc-max-width", cl::Hidden, cl::init(3),
+    cl::desc("Maximum number of branches in the accumulator tree"));
+
 AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
     : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
                           AArch64::CATCHRET),
@@ -6674,6 +6688,127 @@ static bool getMaddPatterns(MachineInstr &Root,
   }
   return Found;
 }
+
+static bool isAccumulationOpcode(unsigned Opcode) {
+  switch (Opcode) {
+  default:
+    break;
+  case AArch64::UABALB_ZZZ_D:
+  case AArch64::UABALB_ZZZ_H:
+  case AArch64::UABALB_ZZZ_S:
+  case AArch64::UABALT_ZZZ_D:
+  case AArch64::UABALT_ZZZ_H:
+  case AArch64::UABALT_ZZZ_S:
+  case AArch64::UABALv16i8_v8i16:
+  case AArch64::UABALv2i32_v2i64:
+  case AArch64::UABALv4i16_v4i32:
+  case AArch64::UABALv4i32_v2i64:
+  case AArch64::UABALv8i16_v4i32:
+  case AArch64::UABALv8i8_v8i16:
+    return true;
+  }
+
+  return false;
+}
+
+static unsigned getAccumulationStartOpCode(unsigned AccumulationOpcode) {
+  switch (AccumulationOpcode) {
+  default:
+    llvm_unreachable("Unknown accumulator opcode");
+  case AArch64::UABALB_ZZZ_D:
+    return AArch64::UABDLB_ZZZ_D;
+  case AArch64::UABALB_ZZZ_H:
+    return AArch64::UABDLB_ZZZ_H;
+  case AArch64::UABALB_ZZZ_S:
+    return AArch64::UABDLB_ZZZ_S;
+  case AArch64::UABALT_ZZZ_D:
+    return AArch64::UABDLT_ZZZ_D;
+  case AArch64::UABALT_ZZZ_H:
+    return AArch64::UABDLT_ZZZ_H;
+  case AArch64::UABALT_ZZZ_S:
+    return AArch64::UABDLT_ZZZ_S;
+  case AArch64::UABALv16i8_v8i16:
+    return AArch64::UABDLv16i8_v8i16;
+  case AArch64::UABALv2i32_v2i64:
+    return AArch64::UABDLv2i32_v2i64;
+  case AArch64::UABALv4i16_v4i32:
+    return AArch64::UABDLv4i16_v4i32;
+  case AArch64::UABALv4i32_v2i64:
+    return AArch64::UABDLv4i32_v2i64;
+  case AArch64::UABALv8i16_v4i32:
+    return AArch64::UABDLv8i16_v4i32;
+  case AArch64::UABALv8i8_v8i16:
+    return AArch64::UABDLv8i8_v8i16;
+  }
+}
+
+static void getAccumulatorChain(MachineInstr *CurrentInstr,
+                                MachineBasicBlock &MBB,
+                                MachineRegisterInfo &MRI,
+                                SmallVectorImpl<Register> &Chain) {
+  // Walk up the chain of accumulation instructions and collect them in the
+  // vector.
+  unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
+  unsigned ChainStartOpCode = getAccumulationStartOpCode(AccumulatorOpcode);
+  while (CurrentInstr &&
+         (canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode) ||
+          canCombine(MBB, CurrentInstr->getOperand(1), ChainStartOpCode))) {
+    Chain.push_back(CurrentInstr->getOperand(0).getReg());
+    CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg());
+  }
+
+  // Add the instruction at the top of the chain.
+  if (CurrentInstr->getOpcode() == ChainStartOpCode)
+    Chain.push_back(CurrentInstr->getOperand(0).getReg());
+}
+
+/// Find chains of accumulations, likely linearized by the reassociation pass,
+/// that can be rewritten as a tree for increased ILP.
+static bool
+getAccumulatorReassociationPatterns(MachineInstr &Root,
+                                    SmallVectorImpl<unsigned> &Patterns) {
+  // Find a chain of depth 4, which would make it profitable to rewrite
+  // as a tree. This pattern should be applied recursively in case we
+  // have a longer chain.
+  if (!EnableAccReassociation)
+    return false;
+
+  unsigned Opc = Root.getOpcode();
+  if (!isAccumulationOpcode(Opc))
+    return false;
+
+  // Verify that this is the end of the chain.
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
+    return false;
+
+  auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
+  if (User->getOpcode() == Opc)
+    return false;
+
+  // Walk up the use chain and collect the reduction chain.
+  SmallVector<Register, 32> Chain;
+  getAccumulatorChain(&Root, MBB, MRI, Chain);
+
+  // Reject chains which are too short to be worth modifying.
+  if (Chain.size() < MinAccumulatorDepth)
+    return false;
+
+  // Check if the MBB this instruction is a part of contains any other chains.
+  // If so, don't apply it.
+  SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end());
+  for (const auto &I : MBB) {
+    if (I.getOpcode() == Opc &&
+        !ReductionChain.contains(I.getOperand(0).getReg()))
+      return false;
+  }
+
+  typedef AArch64MachineCombinerPattern MCP;
+  Patterns.push_back(MCP::ACC_CHAIN);
+  return true;
+}
+
 /// Floating-Point Support

 /// Find instructions that can be turned into madd.
@@ -7061,6 +7196,7 @@ AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
   switch (Pattern) {
   case AArch64MachineCombinerPattern::SUBADD_OP1:
   case AArch64MachineCombinerPattern::SUBADD_OP2:
+  case AArch64MachineCombinerPattern::ACC_CHAIN:
    return CombinerObjective::MustReduceDepth;
  default:
    return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7078,6 +7214,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
   // Integer patterns
   if (getMaddPatterns(Root, Patterns))
     return true;
+  if (getAccumulatorReassociationPatterns(Root, Patterns))
+    return true;
   // Floating point patterns
   if (getFMULPatterns(Root, Patterns))
     return true;
@@ -7436,6 +7574,81 @@ genSubAdd2SubSub(MachineFunction &MF, MachineRegisterInfo &MRI,
   DelInstrs.push_back(&Root);
 }

+static unsigned int
+getReduceOpCodeForAccumulator(unsigned int AccumulatorOpCode) {
+  switch (AccumulatorOpCode) {
+  case AArch64::UABALB_ZZZ_D:
+    return AArch64::ADD_ZZZ_D;
+  case AArch64::UABALB_ZZZ_H:
+    return AArch64::ADD_ZZZ_H;
+  case AArch64::UABALB_ZZZ_S:
+    return AArch64::ADD_ZZZ_S;
+  case AArch64::UABALT_ZZZ_D:
+    return AArch64::ADD_ZZZ_D;
+  case AArch64::UABALT_ZZZ_H:
+    return AArch64::ADD_ZZZ_H;
+  case AArch64::UABALT_ZZZ_S:
+    return AArch64::ADD_ZZZ_S;
+  case AArch64::UABALv16i8_v8i16:
+    return AArch64::ADDv8i16;
+  case AArch64::UABALv2i32_v2i64:
+    return AArch64::ADDv2i64;
+  case AArch64::UABALv4i16_v4i32:
+    return AArch64::ADDv4i32;
+  case AArch64::UABALv4i32_v2i64:
+    return AArch64::ADDv2i64;
+  case AArch64::UABALv8i16_v4i32:
+    return AArch64::ADDv4i32;
+  case AArch64::UABALv8i8_v8i16:
+    return AArch64::ADDv8i16;
+  default:
+    llvm_unreachable("Unknown accumulator opcode");
+  }
+}
+
+// Reduce branches of the accumulator tree by adding them together.
+static void reduceAccumulatorTree(
+    SmallVectorImpl<Register> &RegistersToReduce,
+    SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
+    MachineInstr &Root, MachineRegisterInfo &MRI,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, Register ResultReg) {
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+  SmallVector<Register, 8> NewRegs;
+  for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) {
+    auto RHS = RegistersToReduce[i - 1];
+    auto LHS = RegistersToReduce[i];
+    Register Dest;
+    // If we are reducing 2 registers, reuse the original result register.
+    if (RegistersToReduce.size() == 2)
+      Dest = ResultReg;
+    // Otherwise, create a new virtual register to hold the partial sum.
+    else {
+      auto NewVR = MRI.createVirtualRegister(
+          MRI.getRegClass(Root.getOperand(0).getReg()));
+      Dest = NewVR;
+      NewRegs.push_back(Dest);
+      InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
+    }
+
+    // Create the new add instruction.
+    MachineInstrBuilder MIB =
+        BuildMI(MF, MIMetadata(Root),
+                TII->get(getReduceOpCodeForAccumulator(Root.getOpcode())), Dest)
+            .addReg(RHS, getKillRegState(true))
+            .addReg(LHS, getKillRegState(true));
+    // Copy any flags needed from the original instruction.
+    MIB->setFlags(Root.getFlags());
+    InsInstrs.push_back(MIB);
+  }
+
+  // If the number of registers to reduce is odd, add the remaining register to
+  // the vector of registers to reduce.
+  if (RegistersToReduce.size() % 2 != 0)
+    NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
+
+  RegistersToReduce = NewRegs;
+}
+
 /// When getMachineCombinerPatterns() finds potential patterns,
 /// this function generates the instructions that could replace the
 /// original code sequence
@@ -7671,7 +7884,76 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
     break;
   }
+  case AArch64MachineCombinerPattern::ACC_CHAIN: {
+    SmallVector<Register, 32> ChainRegs;
+    getAccumulatorChain(&Root, MBB, MRI, ChainRegs);
+
+    unsigned int Depth = ChainRegs.size();
+    assert(MaxAccumulatorWidth > 1 &&
+           "Max accumulator width set to illegal value");
+    unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth
+                                ? Log2_32(Depth)
+                                : MaxAccumulatorWidth;
+
+    // Walk down the chain and rewrite it as a tree.
+    for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
+      // No need to rewrite the first node, it is already perfect as it is.
+      if (IndexedReg.index() == 0)
+        continue;
+
+      MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value());
+      MachineInstrBuilder MIB;
+      Register AccReg;
+      if (IndexedReg.index() < MaxWidth) {
+        // Now we need to create new instructions for the first row.
+        AccReg = Instr->getOperand(0).getReg();
+        MIB = BuildMI(
+                  MF, MIMetadata(*Instr),
+                  TII->get(MRI.getUniqueVRegDef(ChainRegs.back())->getOpcode()),
+                  AccReg)
+                  .addReg(Instr->getOperand(2).getReg(),
+                          getKillRegState(Instr->getOperand(2).isKill()))
+                  .addReg(Instr->getOperand(3).getReg(),
+                          getKillRegState(Instr->getOperand(3).isKill()));
+      } else {
+        // For the remaining cases, we need to use an output register of one of
+        // the newly inserted instructions as operand 1.
+        AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg()
+                     ? MRI.createVirtualRegister(
+                           MRI.getRegClass(Root.getOperand(0).getReg()))
+                     : Instr->getOperand(0).getReg();
+        assert(IndexedReg.index() - MaxWidth >= 0);
+        auto AccumulatorInput =
+            ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
+        MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()),
+                      AccReg)
+                  .addReg(AccumulatorInput, getKillRegState(true))
+                  .addReg(Instr->getOperand(2).getReg(),
+                          getKillRegState(Instr->getOperand(2).isKill()))
+                  .addReg(Instr->getOperand(3).getReg(),
+                          getKillRegState(Instr->getOperand(3).isKill()));
+      }
+
+      MIB->setFlags(Instr->getFlags());
+      InstrIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size()));
+      InsInstrs.push_back(MIB);
+      DelInstrs.push_back(Instr);
+    }
+
+    SmallVector<Register, 8> RegistersToReduce;
+    for (int i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size(); ++i) {
+      auto Reg = InsInstrs[i]->getOperand(0).getReg();
+      RegistersToReduce.push_back(Reg);
+    }
+
+    while (RegistersToReduce.size() > 1)
+      reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI,
+                            InstrIdxForVirtReg, Root.getOperand(0).getReg());

+    // We don't want to break, we handle setting flags and adding Root to
+    // DelInstrs from here.
+    return;
+  }
   case AArch64MachineCombinerPattern::MULADDv8i8_OP1:
     Opc = AArch64::MLAv8i8;
     RC = &AArch64::FPR64RegClass;
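
As a rough note on where such chains come from (not taken from this commit or its tests): a sum-of-absolute-differences loop over byte data, as in the hypothetical C++ sketch below, is the kind of code that, once vectorized and unrolled, commonly lowers to UABDL/UABAL sequences on AArch64 and thus produces the linear accumulation chains the ACC_CHAIN pattern rewrites. The function name and the exact codegen are assumptions for illustration.

    // Hypothetical source-level example; names and codegen are assumptions.
    #include <cstdint>
    #include <cstdlib>

    // Sum of absolute differences over two byte buffers; the widening
    // absolute-difference accumulation maps naturally onto UABDL/UABAL.
    uint32_t sad_u8(const uint8_t *a, const uint8_t *b, int n) {
      uint32_t sum = 0;
      for (int i = 0; i < n; ++i)
        sum += std::abs(static_cast<int>(a[i]) - static_cast<int>(b[i]));
      return sum;
    }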

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 2 additions & 0 deletions
@@ -172,6 +172,8 @@ enum AArch64MachineCombinerPattern : unsigned {
   FMULv8i16_indexed_OP2,

   FNMADD,
+
+  ACC_CHAIN
 };
 class AArch64InstrInfo final : public AArch64GenInstrInfo {
   const AArch64RegisterInfo RI;
