Skip to content

Commit 830c373

Browse files
committed
[MachineCombiner] Add a pass to reassociate chains of accumulation instructions into a tree
This pass is designed to increase ILP by performing accumulation into multiple registers. It currently supports only the UABAL accumulation instruction, but can easily be extended to support additional instructions.
1 parent 90d33ee commit 830c373

File tree

3 files changed

+288
-18
lines changed

3 files changed

+288
-18
lines changed

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "Utils/AArch64BaseInfo.h"
2121
#include "llvm/ADT/ArrayRef.h"
2222
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/ADT/SmallSet.h"
2324
#include "llvm/ADT/SmallVector.h"
2425
#include "llvm/CodeGen/LivePhysRegs.h"
2526
#include "llvm/CodeGen/MachineBasicBlock.h"
@@ -78,6 +79,18 @@ static cl::opt<unsigned>
7879
BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
7980
cl::desc("Restrict range of B instructions (DEBUG)"));
8081

82+
// Command-line switch that turns accumulation-chain reassociation on or off.
static cl::opt<bool> EnableAccReassociation(
    "aarch64-acc-reassoc", cl::Hidden, cl::init(true),
    cl::desc("Enable reassociation of accumulation chains"));

// Chains with fewer collected registers than this are left untouched.
static cl::opt<unsigned> MinAccumulatorDepth(
    "aarch64-acc-min-depth", cl::Hidden, cl::init(8),
    cl::desc("Minimum length of accumulator chains required for the optimization to kick in"));

// Upper bound on how many parallel accumulators the rewritten tree may use.
static cl::opt<unsigned> MaxAccumulatorWidth(
    "aarch64-acc-max-width", cl::Hidden, cl::init(3),
    cl::desc("Maximum number of branches in the accumulator tree"));
92+
93+
8194
AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
8295
: AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
8396
AArch64::CATCHRET),
@@ -6674,6 +6687,118 @@ static bool getMaddPatterns(MachineInstr &Root,
66746687
}
66756688
return Found;
66766689
}
6690+
6691+
static bool isAccumulationOpcode(unsigned Opcode) {
6692+
switch(Opcode) {
6693+
default:
6694+
break;
6695+
case AArch64::UABALB_ZZZ_D:
6696+
case AArch64::UABALB_ZZZ_H:
6697+
case AArch64::UABALB_ZZZ_S:
6698+
case AArch64::UABALT_ZZZ_D:
6699+
case AArch64::UABALT_ZZZ_H:
6700+
case AArch64::UABALT_ZZZ_S:
6701+
case AArch64::UABALv16i8_v8i16:
6702+
case AArch64::UABALv2i32_v2i64:
6703+
case AArch64::UABALv4i16_v4i32:
6704+
case AArch64::UABALv4i32_v2i64:
6705+
case AArch64::UABALv8i16_v4i32:
6706+
case AArch64::UABALv8i8_v8i16:
6707+
return true;
6708+
}
6709+
6710+
return false;
6711+
}
6712+
6713+
static unsigned getAccumulationStartOpCode(unsigned AccumulationOpcode) {
6714+
switch(AccumulationOpcode) {
6715+
default:
6716+
llvm_unreachable("Unknown accumulator opcode");
6717+
case AArch64::UABALB_ZZZ_D:
6718+
return AArch64::UABDLB_ZZZ_D;
6719+
case AArch64::UABALB_ZZZ_H:
6720+
return AArch64::UABDLB_ZZZ_H;
6721+
case AArch64::UABALB_ZZZ_S:
6722+
return AArch64::UABDLB_ZZZ_S;
6723+
case AArch64::UABALT_ZZZ_D:
6724+
return AArch64::UABDLT_ZZZ_D;
6725+
case AArch64::UABALT_ZZZ_H:
6726+
return AArch64::UABDLT_ZZZ_H;
6727+
case AArch64::UABALT_ZZZ_S:
6728+
return AArch64::UABDLT_ZZZ_S;
6729+
case AArch64::UABALv16i8_v8i16:
6730+
return AArch64::UABDLv16i8_v8i16;
6731+
case AArch64::UABALv2i32_v2i64:
6732+
return AArch64::UABDLv2i32_v2i64;
6733+
case AArch64::UABALv4i16_v4i32:
6734+
return AArch64::UABDLv4i16_v4i32;
6735+
case AArch64::UABALv4i32_v2i64:
6736+
return AArch64::UABDLv4i32_v2i64;
6737+
case AArch64::UABALv8i16_v4i32:
6738+
return AArch64::UABDLv8i16_v4i32;
6739+
case AArch64::UABALv8i8_v8i16:
6740+
return AArch64::UABDLv8i8_v8i16;
6741+
}
6742+
}
6743+
6744+
/// Collect the result registers of the accumulation chain that ends at
/// \p CurrentInstr.
///
/// The walk follows operand 1 (the accumulator input) upwards for as long as
/// it is produced by a combinable instruction carrying either the same
/// accumulation opcode or the matching chain-start opcode. Registers are
/// appended to \p Chain bottom-up: the result of \p CurrentInstr first, then
/// each producer in turn. If the walk ends on a chain-start instruction, its
/// result is appended last.
///
/// \param CurrentInstr Last instruction of the candidate chain; its opcode
///                     must be accepted by getAccumulationStartOpCode.
/// \param MBB          Block passed through to canCombine.
/// \param MRI          Used to look up the unique definition of each
///                     accumulator input.
/// \param Chain        Output vector the registers are appended to.
static void getAccumulatorChain(MachineInstr *CurrentInstr,
                                MachineBasicBlock &MBB,
                                MachineRegisterInfo &MRI,
                                SmallVectorImpl<Register> &Chain) {
  if (!CurrentInstr)
    return;

  unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
  unsigned ChainStartOpCode = getAccumulationStartOpCode(AccumulatorOpcode);

  // Only accumulation instructions have an accumulator input in operand 1.
  // Once the walk reaches the chain-start instruction it must stop: operand 1
  // of that instruction is an ordinary data operand, and following it would
  // splice an unrelated chain onto this one.
  while (CurrentInstr && CurrentInstr->getOpcode() == AccumulatorOpcode) {
    MachineOperand &AccInput = CurrentInstr->getOperand(1);
    if (!canCombine(MBB, AccInput, AccumulatorOpcode) &&
        !canCombine(MBB, AccInput, ChainStartOpCode))
      break;

    Chain.push_back(CurrentInstr->getOperand(0).getReg());
    CurrentInstr = MRI.getUniqueVRegDef(AccInput.getReg());
  }

  // Add the instruction at the top of the chain.
  if (CurrentInstr && CurrentInstr->getOpcode() == ChainStartOpCode)
    Chain.push_back(CurrentInstr->getOperand(0).getReg());
}
6757+
6758+
/// Find chains of accumulations, likely linearized by reassocation pass,
6759+
/// that can be rewritten as a tree for increased ILP.
6760+
static bool getAccumulatorReassociationPatterns(MachineInstr &Root,
6761+
SmallVectorImpl<unsigned> &Patterns) {
6762+
// find a chain of depth 4, which would make it profitable to rewrite
6763+
// as a tree. This pattern should be applied recursively in case we
6764+
// have a longer chain.
6765+
if (!EnableAccReassociation)
6766+
return false;
6767+
6768+
unsigned Opc = Root.getOpcode();
6769+
if (!isAccumulationOpcode(Opc))
6770+
return false;
6771+
6772+
// Verify that this is the end of the chain.
6773+
MachineBasicBlock &MBB = *Root.getParent();
6774+
MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
6775+
if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
6776+
return false;
6777+
6778+
auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
6779+
if (User->getOpcode() == Opc)
6780+
return false;
6781+
6782+
// Walk up the use chain and collect the reduction chain.
6783+
SmallVector<Register, 32> Chain;
6784+
getAccumulatorChain(&Root, MBB, MRI, Chain);
6785+
6786+
// Reject chains which are too short to be worth modifying.
6787+
if (Chain.size() < MinAccumulatorDepth)
6788+
return false;
6789+
6790+
// Check if the MBB this instruction is a part of contains any other chains. If so, don't apply it.
6791+
SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end());
6792+
for (const auto &I : MBB) {
6793+
if (I.getOpcode() == Opc && !ReductionChain.contains(I.getOperand(0).getReg()))
6794+
return false;
6795+
}
6796+
6797+
typedef AArch64MachineCombinerPattern MCP;
6798+
Patterns.push_back(MCP::ACC_CHAIN);
6799+
return true;
6800+
}
6801+
66776802
/// Floating-Point Support
66786803

66796804
/// Find instructions that can be turned into madd.
@@ -7061,6 +7186,7 @@ AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
70617186
switch (Pattern) {
70627187
case AArch64MachineCombinerPattern::SUBADD_OP1:
70637188
case AArch64MachineCombinerPattern::SUBADD_OP2:
7189+
case AArch64MachineCombinerPattern::ACC_CHAIN:
70647190
return CombinerObjective::MustReduceDepth;
70657191
default:
70667192
return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7078,6 +7204,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
70787204
// Integer patterns
70797205
if (getMaddPatterns(Root, Patterns))
70807206
return true;
7207+
if (getAccumulatorReassociationPatterns(Root, Patterns))
7208+
return true;
70817209
// Floating point patterns
70827210
if (getFMULPatterns(Root, Patterns))
70837211
return true;
@@ -7436,6 +7564,72 @@ genSubAdd2SubSub(MachineFunction &MF, MachineRegisterInfo &MRI,
74367564
DelInstrs.push_back(&Root);
74377565
}
74387566

7567+
static unsigned int getReduceOpCodeForAccumulator(unsigned int AccumulatorOpCode) {
7568+
switch (AccumulatorOpCode) {
7569+
case AArch64::UABALB_ZZZ_D:
7570+
return AArch64::ADD_ZZZ_D;
7571+
case AArch64::UABALB_ZZZ_H:
7572+
return AArch64::ADD_ZZZ_H;
7573+
case AArch64::UABALB_ZZZ_S:
7574+
return AArch64::ADD_ZZZ_S;
7575+
case AArch64::UABALT_ZZZ_D:
7576+
return AArch64::ADD_ZZZ_D;
7577+
case AArch64::UABALT_ZZZ_H:
7578+
return AArch64::ADD_ZZZ_H;
7579+
case AArch64::UABALT_ZZZ_S:
7580+
return AArch64::ADD_ZZZ_S;
7581+
case AArch64::UABALv16i8_v8i16:
7582+
return AArch64::ADDv8i16;
7583+
case AArch64::UABALv2i32_v2i64:
7584+
return AArch64::ADDv2i64;
7585+
case AArch64::UABALv4i16_v4i32:
7586+
return AArch64::ADDv4i32;
7587+
case AArch64::UABALv4i32_v2i64:
7588+
return AArch64::ADDv2i64;
7589+
case AArch64::UABALv8i16_v4i32:
7590+
return AArch64::ADDv4i32;
7591+
case AArch64::UABALv8i8_v8i16:
7592+
return AArch64::ADDv8i16;
7593+
default:
7594+
llvm_unreachable("Unknown accumulator opcode");
7595+
}
7596+
}
7597+
7598+
// Reduce branches of the accumulator tree by adding them together.
7599+
static void reduceAccumulatorTree(SmallVectorImpl<Register> &RegistersToReduce, SmallVectorImpl<MachineInstr *> &InsInstrs,
7600+
MachineFunction &MF, MachineInstr &Root, MachineRegisterInfo &MRI,
7601+
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, Register ResultReg) {
7602+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
7603+
SmallVector<Register, 8> NewRegs;
7604+
for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i+=2) {
7605+
auto RHS = RegistersToReduce[i - 1];
7606+
auto LHS = RegistersToReduce[i];
7607+
Register Dest;
7608+
// If we are reducing 2 registers, reuse the original result register.
7609+
if (RegistersToReduce.size() == 2)
7610+
Dest = ResultReg;
7611+
// Otherwise, create a new virtual register to hold the partial sum.
7612+
else {
7613+
auto NewVR = MRI.createVirtualRegister(MRI.getRegClass(Root.getOperand(0).getReg()));
7614+
Dest = NewVR;
7615+
NewRegs.push_back(Dest);
7616+
InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
7617+
}
7618+
7619+
// Create the new add instruction.
7620+
MachineInstrBuilder MIB = BuildMI(MF, MIMetadata(Root), TII->get(getReduceOpCodeForAccumulator(Root.getOpcode())), Dest).addReg(RHS, getKillRegState(true)).addReg(LHS, getKillRegState(true));
7621+
// Copy any flags needed from the original instruction.
7622+
MIB->setFlags(Root.getFlags());
7623+
InsInstrs.push_back(MIB);
7624+
}
7625+
7626+
// If the number of registers to reduce is odd, add the reminaing register to the vector of registers to reduce.
7627+
if (RegistersToReduce.size() % 2 != 0)
7628+
NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
7629+
7630+
RegistersToReduce = NewRegs;
7631+
}
7632+
74397633
/// When getMachineCombinerPatterns() finds potential patterns,
74407634
/// this function generates the instructions that could replace the
74417635
/// original code sequence
@@ -7671,7 +7865,53 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
76717865
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
76727866
break;
76737867
}
7868+
case AArch64MachineCombinerPattern::ACC_CHAIN: {
7869+
SmallVector<Register, 32> ChainRegs;
7870+
getAccumulatorChain(&Root, MBB, MRI, ChainRegs);
7871+
7872+
unsigned int Depth = ChainRegs.size();
7873+
assert(MaxAccumulatorWidth > 1 && "Max accumulator width set to illegal value");
7874+
unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth ? Log2_32(Depth) : MaxAccumulatorWidth;
7875+
7876+
// Walk down the chain and rewrite it as a tree.
7877+
for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
7878+
// No need to rewrite the first node, it is already perfect as it is.
7879+
if (IndexedReg.index() == 0)
7880+
continue;
7881+
7882+
MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value());
7883+
MachineInstrBuilder MIB;
7884+
Register AccReg;
7885+
if (IndexedReg.index() < MaxWidth) {
7886+
// Now we need to create new instructions for the first row.
7887+
AccReg = Instr->getOperand(0).getReg();
7888+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(MRI.getUniqueVRegDef(ChainRegs.back())->getOpcode()), AccReg).addReg(Instr->getOperand(2).getReg(), getKillRegState(Instr->getOperand(2).isKill())).addReg(Instr->getOperand(3).getReg(), getKillRegState(Instr->getOperand(3).isKill()));
7889+
} else {
7890+
// For the remaining cases, we need to use an output register of one of the newly inserted instructions as operand 1
7891+
AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg() ? MRI.createVirtualRegister(MRI.getRegClass(Root.getOperand(0).getReg())) : Instr->getOperand(0).getReg();
7892+
assert(IndexedReg.index() - MaxWidth >= 0);
7893+
auto AccumulatorInput = ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
7894+
MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()), AccReg).addReg(AccumulatorInput, getKillRegState(true)).addReg(Instr->getOperand(2).getReg(), getKillRegState(Instr->getOperand(2).isKill())).addReg(Instr->getOperand(3).getReg(), getKillRegState(Instr->getOperand(3).isKill()));
7895+
}
7896+
7897+
MIB->setFlags(Instr->getFlags());
7898+
InstrIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size()));
7899+
InsInstrs.push_back(MIB);
7900+
DelInstrs.push_back(Instr);
7901+
}
7902+
7903+
SmallVector<Register, 8> RegistersToReduce;
7904+
for (int i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size(); ++i) {
7905+
auto Reg = InsInstrs[i]->getOperand(0).getReg();
7906+
RegistersToReduce.push_back(Reg);
7907+
}
76747908

7909+
while (RegistersToReduce.size() > 1)
7910+
reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI, InstrIdxForVirtReg, Root.getOperand(0).getReg());
7911+
7912+
// We don't want to break, we handle setting flags and adding Root to DelInstrs from here.
7913+
return;
7914+
}
76757915
case AArch64MachineCombinerPattern::MULADDv8i8_OP1:
76767916
Opc = AArch64::MLAv8i8;
76777917
RC = &AArch64::FPR64RegClass;

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ enum AArch64MachineCombinerPattern : unsigned {
172172
FMULv8i16_indexed_OP2,
173173

174174
FNMADD,
175+
176+
ACC_CHAIN
175177
};
176178
class AArch64InstrInfo final : public AArch64GenInstrInfo {
177179
const AArch64RegisterInfo RI;

0 commit comments

Comments
 (0)