Skip to content

Commit 35cff27

Browse files
committed
[RISCV][VLOPT] Compute demanded VLs up front. NFC
This replaces the worklist by first computing what VL is demanded by each instruction's users. checkUsers essentially already did this, so it has been renamed to computeDemandedVL. The demanded VLs are stored in a DenseMap, and then we can just do a single forward pass of tryReduceVL where we check if a candidate's demanded VL is less than its VLOp. This means the pass should now have linear complexity, and it allows us to relax the restriction on tied operands more easily, as in llvm#124066. Note that in order to avoid std::optional inside the DenseMap, I've removed the std::optionals and replaced them with VLMAX or 0 constant operands.
1 parent cabc640 commit 35cff27

File tree

5 files changed

+380
-85
lines changed

5 files changed

+380
-85
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4235,6 +4235,8 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
42354235

42364236
/// Given two VL operands, do we know that LHS <= RHS?
42374237
bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
4238+
if (LHS.isImm() && LHS.getImm() == 0)
4239+
return true;
42384240
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
42394241
LHS.getReg() == RHS.getReg())
42404242
return true;

llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp

Lines changed: 71 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace {
3333
class RISCVVLOptimizer : public MachineFunctionPass {
3434
const MachineRegisterInfo *MRI;
3535
const MachineDominatorTree *MDT;
36+
const TargetInstrInfo *TII;
3637

3738
public:
3839
static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5051
StringRef getPassName() const override { return PASS_NAME; }
5152

5253
private:
53-
std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
54-
/// Returns the largest common VL MachineOperand that may be used to optimize
55-
/// MI. Returns std::nullopt if it failed to find a suitable VL.
56-
std::optional<MachineOperand> checkUsers(MachineInstr &MI);
54+
MachineOperand getMinimumVLForUser(MachineOperand &UserOp);
55+
/// Computes the VL of \p MI that is actually used by its users.
56+
MachineOperand computeDemandedVL(const MachineInstr &MI);
5757
bool tryReduceVL(MachineInstr &MI);
5858
bool isCandidate(const MachineInstr &MI) const;
59+
60+
/// For a given instruction, records what elements of it are demanded by
61+
/// downstream users.
62+
DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
5963
};
6064

6165
} // end anonymous namespace
@@ -1173,15 +1177,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
11731177
return true;
11741178
}
11751179

1176-
std::optional<MachineOperand>
1177-
RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1180+
MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
11781181
const MachineInstr &UserMI = *UserOp.getParent();
11791182
const MCInstrDesc &Desc = UserMI.getDesc();
11801183

11811184
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
11821185
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
11831186
" use VLMAX\n");
1184-
return std::nullopt;
1187+
return MachineOperand::CreateImm(RISCV::VLMaxSentinel);
11851188
}
11861189

11871190
// Instructions like reductions may use a vector register as a scalar
@@ -1201,46 +1204,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12011204
// Looking for an immediate or a register VL that isn't X0.
12021205
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
12031206
"Did not expect X0 VL");
1207+
1208+
// If we know the demanded VL of UserMI, then we can reduce the VL it
1209+
// requires.
1210+
if (DemandedVLs.contains(&UserMI)) {
1211+
// We can only shrink the demanded VL if the elementwise result doesn't
1212+
// depend on VL (i.e. not vredsum/viota etc.)
1213+
// Also conservatively restrict to supported instructions for now.
1214+
// TODO: Can we remove the isSupportedInstr check?
1215+
if (!RISCVII::elementsDependOnVL(
1216+
TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags) &&
1217+
isSupportedInstr(UserMI)) {
1218+
const MachineOperand &DemandedVL = DemandedVLs.at(&UserMI);
1219+
if (RISCV::isVLKnownLE(DemandedVL, VLOp))
1220+
return DemandedVL;
1221+
}
1222+
}
1223+
12041224
return VLOp;
12051225
}
12061226

1207-
std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1208-
// FIXME: Avoid visiting each user for each time we visit something on the
1209-
// worklist, combined with an extra visit from the outer loop. Restructure
1210-
// along lines of an instcombine style worklist which integrates the outer
1211-
// pass.
1212-
std::optional<MachineOperand> CommonVL;
1227+
MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
1228+
const MachineOperand &VLMAX = MachineOperand::CreateImm(RISCV::VLMaxSentinel);
1229+
MachineOperand DemandedVL = MachineOperand::CreateImm(0);
1230+
12131231
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
12141232
const MachineInstr &UserMI = *UserOp.getParent();
12151233
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
12161234
if (mayReadPastVL(UserMI)) {
12171235
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
1218-
return std::nullopt;
1236+
return VLMAX;
12191237
}
12201238

12211239
// If used as a passthru, elements past VL will be read.
12221240
if (UserOp.isTied()) {
12231241
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
1224-
return std::nullopt;
1242+
return VLMAX;
12251243
}
12261244

1227-
auto VLOp = getMinimumVLForUser(UserOp);
1228-
if (!VLOp)
1229-
return std::nullopt;
1245+
const MachineOperand &VLOp = getMinimumVLForUser(UserOp);
12301246

12311247
// Use the largest VL among all the users. If we cannot determine this
12321248
// statically, then we cannot optimize the VL.
1233-
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
1234-
CommonVL = *VLOp;
1235-
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
1236-
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
1249+
if (RISCV::isVLKnownLE(DemandedVL, VLOp)) {
1250+
DemandedVL = VLOp;
1251+
LLVM_DEBUG(dbgs() << " Demanded VL is: " << VLOp << "\n");
1252+
} else if (!RISCV::isVLKnownLE(VLOp, DemandedVL)) {
12371253
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
1238-
return std::nullopt;
1254+
return VLMAX;
12391255
}
12401256

12411257
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
12421258
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
1243-
return std::nullopt;
1259+
return VLMAX;
12441260
}
12451261

12461262
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
@@ -1250,7 +1266,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12501266
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
12511267
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12521268
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1253-
return std::nullopt;
1269+
return VLMAX;
12541270
}
12551271

12561272
// If the operand is used as a scalar operand, then the EEW must be
@@ -1265,53 +1281,51 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12651281
<< " Abort due to incompatible information for EMUL or EEW.\n");
12661282
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12671283
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1268-
return std::nullopt;
1284+
return VLMAX;
12691285
}
12701286
}
12711287

1272-
return CommonVL;
1288+
return DemandedVL;
12731289
}
12741290

12751291
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
12761292
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
12771293

1278-
auto CommonVL = checkUsers(MI);
1279-
if (!CommonVL)
1280-
return false;
1294+
const MachineOperand &CommonVL = DemandedVLs.at(&MI);
12811295

1282-
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
1296+
assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
12831297
"Expected VL to be an Imm or virtual Reg");
12841298

12851299
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
12861300
MachineOperand &VLOp = MI.getOperand(VLOpNum);
12871301

1288-
if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
1289-
LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
1302+
if (!RISCV::isVLKnownLE(CommonVL, VLOp)) {
1303+
LLVM_DEBUG(dbgs() << " Abort due to DemandedVL not <= VLOp.\n");
12901304
return false;
12911305
}
12921306

1293-
if (CommonVL->isIdenticalTo(VLOp)) {
1307+
if (CommonVL.isIdenticalTo(VLOp)) {
12941308
LLVM_DEBUG(
1295-
dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
1309+
dbgs()
1310+
<< " Abort due to DemandedVL == VLOp, no point in reducing.\n");
12961311
return false;
12971312
}
12981313

1299-
if (CommonVL->isImm()) {
1314+
if (CommonVL.isImm()) {
13001315
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1301-
<< CommonVL->getImm() << " for " << MI << "\n");
1302-
VLOp.ChangeToImmediate(CommonVL->getImm());
1316+
<< CommonVL.getImm() << " for " << MI << "\n");
1317+
VLOp.ChangeToImmediate(CommonVL.getImm());
13031318
return true;
13041319
}
1305-
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
1320+
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL.getReg());
13061321
if (!MDT->dominates(VLMI, &MI))
13071322
return false;
1308-
LLVM_DEBUG(
1309-
dbgs() << " Reduce VL from " << VLOp << " to "
1310-
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
1311-
<< " for " << MI << "\n");
1323+
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1324+
<< printReg(CommonVL.getReg(), MRI->getTargetRegisterInfo())
1325+
<< " for " << MI << "\n");
13121326

13131327
// All our checks passed. We can reduce VL.
1314-
VLOp.ChangeToRegister(CommonVL->getReg(), false);
1328+
VLOp.ChangeToRegister(CommonVL.getReg(), false);
13151329
return true;
13161330
}
13171331

@@ -1326,52 +1340,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13261340
if (!ST.hasVInstructions())
13271341
return false;
13281342

1329-
SetVector<MachineInstr *> Worklist;
1330-
auto PushOperands = [this, &Worklist](MachineInstr &MI,
1331-
bool IgnoreSameBlock) {
1332-
for (auto &Op : MI.operands()) {
1333-
if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual() ||
1334-
!isVectorRegClass(Op.getReg(), MRI))
1335-
continue;
1336-
1337-
MachineInstr *DefMI = MRI->getVRegDef(Op.getReg());
1338-
if (!isCandidate(*DefMI))
1339-
continue;
1340-
1341-
if (IgnoreSameBlock && DefMI->getParent() == MI.getParent())
1342-
continue;
1343-
1344-
Worklist.insert(DefMI);
1345-
}
1346-
};
1343+
TII = ST.getInstrInfo();
13471344

1348-
// Do a first pass eagerly rewriting in roughly reverse instruction
1349-
// order, populate the worklist with any instructions we might need to
1350-
// revisit. We avoid adding definitions to the worklist if they're
1351-
// in the same block - we're about to visit them anyways.
13521345
bool MadeChange = false;
13531346
for (MachineBasicBlock &MBB : MF) {
13541347
// Avoid unreachable blocks as they have degenerate dominance
13551348
if (!MDT->isReachableFromEntry(&MBB))
13561349
continue;
13571350

1358-
for (auto &MI : reverse(MBB)) {
1351+
// For each instruction that defines a vector, compute what VL its
1352+
// downstream users demand.
1353+
for (const auto &MI : reverse(MBB)) {
1354+
if (!isCandidate(MI))
1355+
continue;
1356+
DemandedVLs.insert({&MI, computeDemandedVL(MI)});
1357+
}
1358+
1359+
// Then go through and see if we can reduce the VL of any instructions to
1360+
// only what's demanded.
1361+
for (auto &MI : MBB) {
13591362
if (!isCandidate(MI))
13601363
continue;
13611364
if (!tryReduceVL(MI))
13621365
continue;
13631366
MadeChange = true;
1364-
PushOperands(MI, /*IgnoreSameBlock*/ true);
13651367
}
1366-
}
13671368

1368-
while (!Worklist.empty()) {
1369-
assert(MadeChange);
1370-
MachineInstr &MI = *Worklist.pop_back_val();
1371-
assert(isCandidate(MI));
1372-
if (!tryReduceVL(MI))
1373-
continue;
1374-
PushOperands(MI, /*IgnoreSameBlock*/ false);
1369+
DemandedVLs.clear();
13751370
}
13761371

13771372
return MadeChange;

0 commit comments

Comments
 (0)