Skip to content

Commit 8e67ee2

Browse files
committed
[CodeGen][TII] Allow reassociation on custom operand indices
This opens up a door for reusing reassociation optimizations on target-specific binary operations with non-standard operand list. This is effectively a NFC.
1 parent 81cdd35 commit 8e67ee2

File tree

3 files changed

+115
-49
lines changed

3 files changed

+115
-49
lines changed

llvm/include/llvm/CodeGen/TargetInstrInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/MC/MCInstrInfo.h"
3131
#include "llvm/Support/BranchProbability.h"
3232
#include "llvm/Support/ErrorHandling.h"
33+
#include <array>
3334
#include <cassert>
3435
#include <cstddef>
3536
#include <cstdint>
@@ -1268,12 +1269,22 @@ class TargetInstrInfo : public MCInstrInfo {
12681269
return true;
12691270
}
12701271

1272+
/// The returned array encodes the operand index for each parameter because
1273+
/// the operands may be commuted; the operand indices for associative
1274+
/// operations might also be target-specific. Each element specifies the index
1275+
/// of {Prev, A, B, X, Y}.
1276+
virtual void
1277+
getReassociateOperandIdx(const MachineInstr &Root,
1278+
MachineCombinerPattern Pattern,
1279+
std::array<unsigned, 5> &OperandIndices) const;
1280+
12711281
/// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
12721282
/// reduce critical path length.
12731283
void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
12741284
MachineCombinerPattern Pattern,
12751285
SmallVectorImpl<MachineInstr *> &InsInstrs,
12761286
SmallVectorImpl<MachineInstr *> &DelInstrs,
1287+
ArrayRef<unsigned> OperandIndices,
12771288
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const;
12781289

12791290
/// Reassociation of some instructions requires inverse operations (e.g.

llvm/lib/CodeGen/TargetInstrInfo.cpp

Lines changed: 100 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,43 +1051,45 @@ static std::pair<bool, bool> mustSwapOperands(MachineCombinerPattern Pattern) {
10511051
}
10521052
}
10531053

1054+
void TargetInstrInfo::getReassociateOperandIdx(
1055+
const MachineInstr &Root, MachineCombinerPattern Pattern,
1056+
std::array<unsigned, 5> &OperandIndices) const {
1057+
switch (Pattern) {
1058+
case MachineCombinerPattern::REASSOC_AX_BY:
1059+
OperandIndices = {1, 1, 1, 2, 2};
1060+
break;
1061+
case MachineCombinerPattern::REASSOC_AX_YB:
1062+
OperandIndices = {2, 1, 2, 2, 1};
1063+
break;
1064+
case MachineCombinerPattern::REASSOC_XA_BY:
1065+
OperandIndices = {1, 2, 1, 1, 2};
1066+
break;
1067+
case MachineCombinerPattern::REASSOC_XA_YB:
1068+
OperandIndices = {2, 2, 2, 1, 1};
1069+
break;
1070+
default:
1071+
llvm_unreachable("unexpected MachineCombinerPattern");
1072+
}
1073+
}
1074+
10541075
/// Attempt the reassociation transformation to reduce critical path length.
10551076
/// See the above comments before getMachineCombinerPatterns().
10561077
void TargetInstrInfo::reassociateOps(
1057-
MachineInstr &Root, MachineInstr &Prev,
1058-
MachineCombinerPattern Pattern,
1078+
MachineInstr &Root, MachineInstr &Prev, MachineCombinerPattern Pattern,
10591079
SmallVectorImpl<MachineInstr *> &InsInstrs,
10601080
SmallVectorImpl<MachineInstr *> &DelInstrs,
1081+
ArrayRef<unsigned> OperandIndices,
10611082
DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
10621083
MachineFunction *MF = Root.getMF();
10631084
MachineRegisterInfo &MRI = MF->getRegInfo();
10641085
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
10651086
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
10661087
const TargetRegisterClass *RC = Root.getRegClassConstraint(0, TII, TRI);
10671088

1068-
// This array encodes the operand index for each parameter because the
1069-
// operands may be commuted. Each row corresponds to a pattern value,
1070-
// and each column specifies the index of A, B, X, Y.
1071-
unsigned OpIdx[4][4] = {
1072-
{ 1, 1, 2, 2 },
1073-
{ 1, 2, 2, 1 },
1074-
{ 2, 1, 1, 2 },
1075-
{ 2, 2, 1, 1 }
1076-
};
1077-
1078-
int Row;
1079-
switch (Pattern) {
1080-
case MachineCombinerPattern::REASSOC_AX_BY: Row = 0; break;
1081-
case MachineCombinerPattern::REASSOC_AX_YB: Row = 1; break;
1082-
case MachineCombinerPattern::REASSOC_XA_BY: Row = 2; break;
1083-
case MachineCombinerPattern::REASSOC_XA_YB: Row = 3; break;
1084-
default: llvm_unreachable("unexpected MachineCombinerPattern");
1085-
}
1086-
1087-
MachineOperand &OpA = Prev.getOperand(OpIdx[Row][0]);
1088-
MachineOperand &OpB = Root.getOperand(OpIdx[Row][1]);
1089-
MachineOperand &OpX = Prev.getOperand(OpIdx[Row][2]);
1090-
MachineOperand &OpY = Root.getOperand(OpIdx[Row][3]);
1089+
MachineOperand &OpA = Prev.getOperand(OperandIndices[1]);
1090+
MachineOperand &OpB = Root.getOperand(OperandIndices[2]);
1091+
MachineOperand &OpX = Prev.getOperand(OperandIndices[3]);
1092+
MachineOperand &OpY = Root.getOperand(OperandIndices[4]);
10911093
MachineOperand &OpC = Root.getOperand(0);
10921094

10931095
Register RegA = OpA.getReg();
@@ -1126,21 +1128,83 @@ void TargetInstrInfo::reassociateOps(
11261128
std::swap(KillX, KillY);
11271129
}
11281130

1131+
unsigned PrevFirstOpIdx, PrevSecondOpIdx;
1132+
unsigned RootFirstOpIdx, RootSecondOpIdx;
1133+
switch (Pattern) {
1134+
case MachineCombinerPattern::REASSOC_AX_BY:
1135+
PrevFirstOpIdx = OperandIndices[1];
1136+
PrevSecondOpIdx = OperandIndices[3];
1137+
RootFirstOpIdx = OperandIndices[2];
1138+
RootSecondOpIdx = OperandIndices[4];
1139+
break;
1140+
case MachineCombinerPattern::REASSOC_AX_YB:
1141+
PrevFirstOpIdx = OperandIndices[1];
1142+
PrevSecondOpIdx = OperandIndices[3];
1143+
RootFirstOpIdx = OperandIndices[4];
1144+
RootSecondOpIdx = OperandIndices[2];
1145+
break;
1146+
case MachineCombinerPattern::REASSOC_XA_BY:
1147+
PrevFirstOpIdx = OperandIndices[3];
1148+
PrevSecondOpIdx = OperandIndices[1];
1149+
RootFirstOpIdx = OperandIndices[2];
1150+
RootSecondOpIdx = OperandIndices[4];
1151+
break;
1152+
case MachineCombinerPattern::REASSOC_XA_YB:
1153+
PrevFirstOpIdx = OperandIndices[3];
1154+
PrevSecondOpIdx = OperandIndices[1];
1155+
RootFirstOpIdx = OperandIndices[4];
1156+
RootSecondOpIdx = OperandIndices[2];
1157+
break;
1158+
default:
1159+
llvm_unreachable("unexpected MachineCombinerPattern");
1160+
}
1161+
1162+
// Basically BuildMI but doesn't add implicit operands by default.
1163+
auto buildMINoImplicit = [](MachineFunction &MF, const MIMetadata &MIMD,
1164+
const MCInstrDesc &MCID, Register DestReg) {
1165+
return MachineInstrBuilder(
1166+
MF, MF.CreateMachineInstr(MCID, MIMD.getDL(), /*NoImpl=*/true))
1167+
.setPCSections(MIMD.getPCSections())
1168+
.addReg(DestReg, RegState::Define);
1169+
};
1170+
11291171
// Create new instructions for insertion.
11301172
MachineInstrBuilder MIB1 =
1131-
BuildMI(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR)
1132-
.addReg(RegX, getKillRegState(KillX))
1133-
.addReg(RegY, getKillRegState(KillY));
1173+
buildMINoImplicit(*MF, MIMetadata(Prev), TII->get(NewPrevOpc), NewVR);
1174+
for (const auto &MO : Prev.explicit_operands()) {
1175+
unsigned Idx = MO.getOperandNo();
1176+
// Skip the result operand we'd already added.
1177+
if (Idx == 0)
1178+
continue;
1179+
if (Idx == PrevFirstOpIdx)
1180+
MIB1.addReg(RegX, getKillRegState(KillX));
1181+
else if (Idx == PrevSecondOpIdx)
1182+
MIB1.addReg(RegY, getKillRegState(KillY));
1183+
else
1184+
MIB1.add(MO);
1185+
}
1186+
MIB1.copyImplicitOps(Prev);
11341187

11351188
if (SwapRootOperands) {
11361189
std::swap(RegA, NewVR);
11371190
std::swap(KillA, KillNewVR);
11381191
}
11391192

11401193
MachineInstrBuilder MIB2 =
1141-
BuildMI(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC)
1142-
.addReg(RegA, getKillRegState(KillA))
1143-
.addReg(NewVR, getKillRegState(KillNewVR));
1194+
buildMINoImplicit(*MF, MIMetadata(Root), TII->get(NewRootOpc), RegC);
1195+
for (const auto &MO : Root.explicit_operands()) {
1196+
unsigned Idx = MO.getOperandNo();
1197+
// Skip the result operand.
1198+
if (Idx == 0)
1199+
continue;
1200+
if (Idx == RootFirstOpIdx)
1201+
MIB2 = MIB2.addReg(RegA, getKillRegState(KillA));
1202+
else if (Idx == RootSecondOpIdx)
1203+
MIB2 = MIB2.addReg(NewVR, getKillRegState(KillNewVR));
1204+
else
1205+
MIB2 = MIB2.add(MO);
1206+
}
1207+
MIB2.copyImplicitOps(Root);
11441208

11451209
// Propagate FP flags from the original instructions.
11461210
// But clear poison-generating flags because those may not be valid now.
@@ -1184,25 +1248,16 @@ void TargetInstrInfo::genAlternativeCodeSequence(
11841248
MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
11851249

11861250
// Select the previous instruction in the sequence based on the input pattern.
1187-
MachineInstr *Prev = nullptr;
1188-
switch (Pattern) {
1189-
case MachineCombinerPattern::REASSOC_AX_BY:
1190-
case MachineCombinerPattern::REASSOC_XA_BY:
1191-
Prev = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
1192-
break;
1193-
case MachineCombinerPattern::REASSOC_AX_YB:
1194-
case MachineCombinerPattern::REASSOC_XA_YB:
1195-
Prev = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
1196-
break;
1197-
default:
1198-
llvm_unreachable("Unknown pattern for machine combiner");
1199-
}
1251+
std::array<unsigned, 5> OpIdx;
1252+
getReassociateOperandIdx(Root, Pattern, OpIdx);
1253+
MachineInstr *Prev = MRI.getUniqueVRegDef(Root.getOperand(OpIdx[0]).getReg());
12001254

12011255
// Don't reassociate if Prev and Root are in different blocks.
12021256
if (Prev->getParent() != Root.getParent())
12031257
return;
12041258

1205-
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg);
1259+
reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OpIdx,
1260+
InstIdxForVirtReg);
12061261
}
12071262

12081263
MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const {

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,10 +1575,10 @@ void RISCVInstrInfo::finalizeInsInstrs(
15751575
MachineFunction &MF = *Root.getMF();
15761576

15771577
for (auto *NewMI : InsInstrs) {
1578-
assert(static_cast<unsigned>(RISCV::getNamedOperandIdx(
1579-
NewMI->getOpcode(), RISCV::OpName::frm)) ==
1580-
NewMI->getNumOperands() &&
1581-
"Instruction has unexpected number of operands");
1578+
// We'd already added the FRM operand.
1579+
if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
1580+
NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
1581+
continue;
15821582
MachineInstrBuilder MIB(MF, NewMI);
15831583
MIB.add(FRM);
15841584
if (FRM.getImm() == RISCVFPRndMode::DYN)

0 commit comments

Comments
 (0)