@@ -33,6 +33,7 @@ namespace {
33
33
class RISCVVLOptimizer : public MachineFunctionPass {
34
34
const MachineRegisterInfo *MRI;
35
35
const MachineDominatorTree *MDT;
36
+ const TargetInstrInfo *TII;
36
37
37
38
public:
38
39
static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
50
51
StringRef getPassName () const override { return PASS_NAME; }
51
52
52
53
private:
53
- std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
54
- // / Returns the largest common VL MachineOperand that may be used to optimize
55
- // / MI. Returns std::nullopt if it failed to find a suitable VL.
56
- std::optional<MachineOperand> checkUsers (MachineInstr &MI);
54
+ MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55
+ /// Computes the VL of \p MI that is actually used by its users.
56
+ MachineOperand computeDemandedVL (const MachineInstr &MI);
57
57
bool tryReduceVL (MachineInstr &MI);
58
58
bool isCandidate (const MachineInstr &MI) const ;
59
+
60
+ /// For a given instruction, records what elements of it are demanded by
61
+ /// downstream users.
62
+ DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
59
63
};
60
64
61
65
} // end anonymous namespace
@@ -1173,15 +1177,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
1173
1177
return true ;
1174
1178
}
1175
1179
1176
- std::optional<MachineOperand>
1177
- RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1180
+ MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1178
1181
const MachineInstr &UserMI = *UserOp.getParent ();
1179
1182
const MCInstrDesc &Desc = UserMI.getDesc ();
1180
1183
1181
1184
if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
1182
1185
LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
1183
1186
" use VLMAX\n " );
1184
- return std::nullopt ;
1187
+ return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
1185
1188
}
1186
1189
1187
1190
// Instructions like reductions may use a vector register as a scalar
@@ -1201,46 +1204,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1201
1204
// Looking for an immediate or a register VL that isn't X0.
1202
1205
assert ((!VLOp.isReg () || VLOp.getReg () != RISCV::X0) &&
1203
1206
" Did not expect X0 VL" );
1207
+
1208
+ // If we know the demanded VL of UserMI, then we can reduce the VL it
1209
+ // requires.
1210
+ if (DemandedVLs.contains (&UserMI)) {
1211
+ // We can only shrink the demanded VL if the elementwise result doesn't
1212
+ // depend on VL (i.e. not vredsum/viota etc.)
1213
+ // Also conservatively restrict to supported instructions for now.
1214
+ // TODO: Can we remove the isSupportedInstr check?
1215
+ if (!RISCVII::elementsDependOnVL (
1216
+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1217
+ isSupportedInstr (UserMI)) {
1218
+ const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
1219
+ if (RISCV::isVLKnownLE (DemandedVL, VLOp))
1220
+ return DemandedVL;
1221
+ }
1222
+ }
1223
+
1204
1224
return VLOp;
1205
1225
}
1206
1226
1207
- std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1208
- // FIXME: Avoid visiting each user for each time we visit something on the
1209
- // worklist, combined with an extra visit from the outer loop. Restructure
1210
- // along lines of an instcombine style worklist which integrates the outer
1211
- // pass.
1212
- std::optional<MachineOperand> CommonVL;
1227
+ MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1228
+ const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1229
+ MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1230
+
1213
1231
for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
1214
1232
const MachineInstr &UserMI = *UserOp.getParent ();
1215
1233
LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
1216
1234
if (mayReadPastVL (UserMI)) {
1217
1235
LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1218
- return std::nullopt ;
1236
+ return VLMAX ;
1219
1237
}
1220
1238
1221
1239
// If used as a passthru, elements past VL will be read.
1222
1240
if (UserOp.isTied ()) {
1223
1241
LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1224
- return std::nullopt ;
1242
+ return VLMAX ;
1225
1243
}
1226
1244
1227
- auto VLOp = getMinimumVLForUser (UserOp);
1228
- if (!VLOp)
1229
- return std::nullopt;
1245
+ const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
1230
1246
1231
1247
// Use the largest VL among all the users. If we cannot determine this
1232
1248
// statically, then we cannot optimize the VL.
1233
- if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, * VLOp)) {
1234
- CommonVL = * VLOp;
1235
- LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1236
- } else if (!RISCV::isVLKnownLE (* VLOp, *CommonVL )) {
1249
+ if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1250
+ DemandedVL = VLOp;
1251
+ LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1252
+ } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL )) {
1237
1253
LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1238
- return std::nullopt ;
1254
+ return VLMAX ;
1239
1255
}
1240
1256
1241
1257
if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
1242
1258
LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1243
- return std::nullopt ;
1259
+ return VLMAX ;
1244
1260
}
1245
1261
1246
1262
std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1250,7 +1266,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1250
1266
LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
1251
1267
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1252
1268
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1253
- return std::nullopt ;
1269
+ return VLMAX ;
1254
1270
}
1255
1271
1256
1272
// If the operand is used as a scalar operand, then the EEW must be
@@ -1265,53 +1281,51 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1265
1281
<< " Abort due to incompatible information for EMUL or EEW.\n " );
1266
1282
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1267
1283
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1268
- return std::nullopt ;
1284
+ return VLMAX ;
1269
1285
}
1270
1286
}
1271
1287
1272
- return CommonVL ;
1288
+ return DemandedVL ;
1273
1289
}
1274
1290
1275
1291
// NOTE: this span is a diff rendering. Bare numbers are old/new line-number
// gutters, lines starting with '-' are the removed version and lines starting
// with '+' are the added version; unmarked lines are unchanged context.
//
// tryReduceVL: attempts to shrink the VL operand of MI.
//   - Added ('+') version: reads the demanded VL recorded for MI from
//     DemandedVLs (via at(), so an entry for MI is expected to exist already).
//   - Removed ('-') version: called checkUsers(MI) and returned false when it
//     produced no value.
// Returns false without modifying MI when the demanded VL is not known to be
// <= the current VL operand, when it is identical to the current VL operand,
// or when it is a register whose defining instruction does not dominate MI.
// Otherwise rewrites the VL operand in place (to an immediate, or to the
// virtual register) and returns true.
bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
1276
1292
LLVM_DEBUG (dbgs () << " Trying to reduce VL for " << MI << " \n " );
1277
1293
1278
- auto CommonVL = checkUsers (MI);
1279
- if (!CommonVL)
1280
- return false ;
1294
+ const MachineOperand &CommonVL = DemandedVLs.at (&MI);
1281
1295
1282
- assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
1296
+ assert ((CommonVL. isImm () || CommonVL. getReg ().isVirtual ()) &&
1283
1297
" Expected VL to be an Imm or virtual Reg" );
1284
1298
1285
1299
unsigned VLOpNum = RISCVII::getVLOpNum (MI.getDesc ());
1286
1300
MachineOperand &VLOp = MI.getOperand (VLOpNum);
1287
1301
1288
- if (!RISCV::isVLKnownLE (* CommonVL, VLOp)) {
1289
- LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
1302
+ if (!RISCV::isVLKnownLE (CommonVL, VLOp)) {
1303
+ LLVM_DEBUG (dbgs () << " Abort due to DemandedVL not <= VLOp.\n " );
1290
1304
return false ;
1291
1305
}
1292
1306
1293
- if (CommonVL-> isIdenticalTo (VLOp)) {
1307
+ if (CommonVL. isIdenticalTo (VLOp)) {
1294
1308
LLVM_DEBUG (
1295
- dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
1309
+ dbgs ()
1310
+ << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
1296
1311
return false ;
1297
1312
}
1298
1313
1299
- if (CommonVL-> isImm ()) {
1314
+ if (CommonVL. isImm ()) {
1300
1315
LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1301
- << CommonVL-> getImm () << " for " << MI << " \n " );
1302
- VLOp.ChangeToImmediate (CommonVL-> getImm ());
1316
+ << CommonVL. getImm () << " for " << MI << " \n " );
1317
+ VLOp.ChangeToImmediate (CommonVL. getImm ());
1303
1318
return true ;
1304
1319
}
1305
- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
1320
+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL. getReg ());
1306
1321
if (!MDT->dominates (VLMI, &MI))
1307
1322
return false ;
1308
- LLVM_DEBUG (
1309
- dbgs () << " Reduce VL from " << VLOp << " to "
1310
- << printReg (CommonVL->getReg (), MRI->getTargetRegisterInfo ())
1311
- << " for " << MI << " \n " );
1323
+ LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1324
+ << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1325
+ << " for " << MI << " \n " );
1312
1326
1313
1327
// All our checks passed. We can reduce VL.
1314
- VLOp.ChangeToRegister (CommonVL-> getReg (), false );
1328
+ VLOp.ChangeToRegister (CommonVL. getReg (), false );
1315
1329
return true ;
1316
1330
}
1317
1331
@@ -1326,52 +1340,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
1326
1340
if (!ST.hasVInstructions ())
1327
1341
return false ;
1328
1342
1329
- SetVector<MachineInstr *> Worklist;
1330
- auto PushOperands = [this , &Worklist](MachineInstr &MI,
1331
- bool IgnoreSameBlock) {
1332
- for (auto &Op : MI.operands ()) {
1333
- if (!Op.isReg () || !Op.isUse () || !Op.getReg ().isVirtual () ||
1334
- !isVectorRegClass (Op.getReg (), MRI))
1335
- continue ;
1336
-
1337
- MachineInstr *DefMI = MRI->getVRegDef (Op.getReg ());
1338
- if (!isCandidate (*DefMI))
1339
- continue ;
1340
-
1341
- if (IgnoreSameBlock && DefMI->getParent () == MI.getParent ())
1342
- continue ;
1343
-
1344
- Worklist.insert (DefMI);
1345
- }
1346
- };
1343
+ TII = ST.getInstrInfo ();
1347
1344
1348
- // Do a first pass eagerly rewriting in roughly reverse instruction
1349
- // order, populate the worklist with any instructions we might need to
1350
- // revisit. We avoid adding definitions to the worklist if they're
1351
- // in the same block - we're about to visit them anyways.
1352
1345
bool MadeChange = false ;
1353
1346
for (MachineBasicBlock &MBB : MF) {
1354
1347
// Avoid unreachable blocks as they have degenerate dominance
1355
1348
if (!MDT->isReachableFromEntry (&MBB))
1356
1349
continue ;
1357
1350
1358
- for (auto &MI : reverse (MBB)) {
1351
+ // For each instruction that defines a vector, compute what VL its
1352
+ // downstream users demand.
1353
+ for (const auto &MI : reverse (MBB)) {
1354
+ if (!isCandidate (MI))
1355
+ continue ;
1356
+ DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1357
+ }
1358
+
1359
+ // Then go through and see if we can reduce the VL of any instructions to
1360
+ // only what's demanded.
1361
+ for (auto &MI : MBB) {
1359
1362
if (!isCandidate (MI))
1360
1363
continue ;
1361
1364
if (!tryReduceVL (MI))
1362
1365
continue ;
1363
1366
MadeChange = true ;
1364
- PushOperands (MI, /* IgnoreSameBlock*/ true );
1365
1367
}
1366
- }
1367
1368
1368
- while (!Worklist.empty ()) {
1369
- assert (MadeChange);
1370
- MachineInstr &MI = *Worklist.pop_back_val ();
1371
- assert (isCandidate (MI));
1372
- if (!tryReduceVL (MI))
1373
- continue ;
1374
- PushOperands (MI, /* IgnoreSameBlock*/ false );
1369
+ DemandedVLs.clear ();
1375
1370
}
1376
1371
1377
1372
return MadeChange;
0 commit comments