@@ -1014,6 +1014,109 @@ RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) {
1014
1014
}
1015
1015
}
1016
1016
1017
+ // Return true if MO definitely contains the value one.
1018
+ static bool isOne (MachineOperand &MO) {
1019
+ if (MO.isImm () && MO.getImm () == 1 )
1020
+ return true ;
1021
+
1022
+ if (!MO.isReg () || !MO.getReg ().isVirtual ())
1023
+ return false ;
1024
+
1025
+ MachineRegisterInfo &MRI =
1026
+ MO.getParent ()->getParent ()->getParent ()->getRegInfo ();
1027
+ MachineInstr *DefMI = MRI.getUniqueVRegDef (MO.getReg ());
1028
+ if (!DefMI)
1029
+ return false ;
1030
+
1031
+ // For now, just check the canonical one value.
1032
+ if (DefMI->getOpcode () == RISCV::ADDI &&
1033
+ DefMI->getOperand (1 ).getReg () == RISCV::X0 &&
1034
+ DefMI->getOperand (2 ).getImm () == 1 )
1035
+ return true ;
1036
+
1037
+ return false ;
1038
+ }
1039
+
1040
+ // Return true if MO definitely contains the value zero.
1041
+ static bool isZero (MachineOperand &MO) {
1042
+ if (MO.isImm () && MO.getImm () == 0 )
1043
+ return true ;
1044
+ if (MO.isReg () && MO.getReg () == RISCV::X0)
1045
+ return true ;
1046
+ return false ;
1047
+ }
1048
+
1049
+ bool RISCVInstrInfo::trySimplifyCondBr (
1050
+ MachineBasicBlock &MBB, MachineBasicBlock *TBB, MachineBasicBlock *FBB,
1051
+ SmallVectorImpl<MachineOperand> &Cond) const {
1052
+
1053
+ if (!TBB || Cond.size () != 3 )
1054
+ return false ;
1055
+
1056
+ RISCVCC::CondCode CC = static_cast <RISCVCC::CondCode>(Cond[0 ].getImm ());
1057
+ auto LHS = Cond[1 ];
1058
+ auto RHS = Cond[2 ];
1059
+
1060
+ MachineBasicBlock *Folded = nullptr ;
1061
+ switch (CC) {
1062
+ default :
1063
+ // TODO: Implement for more CCs
1064
+ return false ;
1065
+ case RISCVCC::COND_EQ: {
1066
+ // We can statically evaluate that we take the first branch
1067
+ if ((isZero (LHS) && isZero (RHS)) || (isOne (LHS) && isOne (RHS))) {
1068
+ Folded = TBB;
1069
+ break ;
1070
+ }
1071
+ // We can statically evaluate that we take the second branch
1072
+ if ((isZero (LHS) && isOne (RHS)) || (isOne (LHS) && isZero (RHS))) {
1073
+ Folded = FBB;
1074
+ break ;
1075
+ }
1076
+ return false ;
1077
+ }
1078
+ case RISCVCC::COND_NE: {
1079
+ // We can statically evaluate that we take the first branch
1080
+ if ((isOne (LHS) && isZero (RHS)) || (isZero (LHS) && isOne (RHS))) {
1081
+ Folded = TBB;
1082
+ break ;
1083
+ }
1084
+ // We can statically evaluate that we take the second branch
1085
+ if ((isZero (LHS) && isZero (RHS)) || (isOne (LHS) && isOne (RHS))) {
1086
+ Folded = FBB;
1087
+ break ;
1088
+ }
1089
+ return false ;
1090
+ }
1091
+ }
1092
+
1093
+ // At this point, its legal to optimize.
1094
+ removeBranch (MBB);
1095
+ Cond.clear ();
1096
+
1097
+ // Only need to insert a branch if we're not falling through.
1098
+ if (Folded) {
1099
+ DebugLoc DL = MBB.findBranchDebugLoc ();
1100
+ insertBranch (MBB, Folded, nullptr , {}, DL);
1101
+ }
1102
+
1103
+ // Update the successors. Remove them all and add back the correct one.
1104
+ while (!MBB.succ_empty ())
1105
+ MBB.removeSuccessor (MBB.succ_end () - 1 );
1106
+
1107
+ // If it's a fallthrough, we need to figure out where MBB is going.
1108
+ if (!Folded) {
1109
+ MachineFunction::iterator Fallthrough = ++MBB.getIterator ();
1110
+ if (Fallthrough != MBB.getParent ()->end ())
1111
+ MBB.addSuccessor (&*Fallthrough);
1112
+ } else
1113
+ MBB.addSuccessor (Folded);
1114
+
1115
+ TBB = Folded;
1116
+ FBB = nullptr ;
1117
+ return true ;
1118
+ }
1119
+
1017
1120
bool RISCVInstrInfo::analyzeBranch (MachineBasicBlock &MBB,
1018
1121
MachineBasicBlock *&TBB,
1019
1122
MachineBasicBlock *&FBB,
@@ -1071,12 +1174,9 @@ bool RISCVInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
1071
1174
// Handle a single conditional branch.
1072
1175
if (NumTerminators == 1 && I->getDesc ().isConditionalBranch ()) {
1073
1176
parseCondBranch (*I, TBB, Cond);
1074
- // Try and optimize the conditional branch.
1075
- if (AllowModify) {
1076
- optimizeCondBranch (*I);
1077
- // The branch might have changed, reanalyze it.
1078
- return analyzeBranch (MBB, TBB, FBB, Cond, false );
1079
- }
1177
+ // Try to fold the conditional branch into the fallthrough.
1178
+ if (AllowModify)
1179
+ trySimplifyCondBr (MBB, TBB, FBB, Cond);
1080
1180
return false ;
1081
1181
}
1082
1182
@@ -1085,14 +1185,10 @@ bool RISCVInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
1085
1185
I->getDesc ().isUnconditionalBranch ()) {
1086
1186
parseCondBranch (*std::prev (I), TBB, Cond);
1087
1187
FBB = getBranchDestBlock (*I);
1088
- // Try and optimize the pair.
1089
- if (AllowModify) {
1090
- if (optimizeCondBranch (*std::prev (I)))
1091
- I->eraseFromParent ();
1092
-
1093
- // The branch might have changed, reanalyze it.
1094
- return analyzeBranch (MBB, TBB, FBB, Cond, false );
1095
- }
1188
+ // Try to fold the branch of the conditional branch into an unconditional
1189
+ // branch.
1190
+ if (AllowModify)
1191
+ trySimplifyCondBr (MBB, TBB, FBB, Cond);
1096
1192
return false ;
1097
1193
}
1098
1194
@@ -1248,8 +1344,7 @@ bool RISCVInstrInfo::reverseBranchCondition(
1248
1344
1249
1345
bool RISCVInstrInfo::optimizeCondBranch (MachineInstr &MI) const {
1250
1346
MachineBasicBlock *MBB = MI.getParent ();
1251
- if (!MBB)
1252
- return false ;
1347
+ MachineRegisterInfo &MRI = MBB->getParent ()->getRegInfo ();
1253
1348
1254
1349
MachineBasicBlock *TBB, *FBB;
1255
1350
SmallVector<MachineOperand, 3 > Cond;
@@ -1259,97 +1354,8 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
1259
1354
RISCVCC::CondCode CC = static_cast <RISCVCC::CondCode>(Cond[0 ].getImm ());
1260
1355
assert (CC != RISCVCC::COND_INVALID);
1261
1356
1262
- // Right now we only care about LI (i.e. ADDI x0, imm)
1263
- auto isLoadImm = [](const MachineInstr *MI, int64_t &Imm) -> bool {
1264
- if (MI->getOpcode () == RISCV::ADDI && MI->getOperand (1 ).isReg () &&
1265
- MI->getOperand (1 ).getReg () == RISCV::X0) {
1266
- Imm = MI->getOperand (2 ).getImm ();
1267
- return true ;
1268
- }
1357
+ if (CC == RISCVCC::COND_EQ || CC == RISCVCC::COND_NE)
1269
1358
return false ;
1270
- };
1271
-
1272
- MachineRegisterInfo &MRI = MBB->getParent ()->getRegInfo ();
1273
- // Either a load from immediate instruction or X0.
1274
- auto isFromLoadImm = [&](const MachineOperand &Op, int64_t &Imm) -> bool {
1275
- if (!Op.isReg ())
1276
- return false ;
1277
- Register Reg = Op.getReg ();
1278
- if (Reg == RISCV::X0) {
1279
- Imm = 0 ;
1280
- return true ;
1281
- }
1282
- return Reg.isVirtual () && isLoadImm (MRI.getVRegDef (Reg), Imm);
1283
- };
1284
-
1285
- // Try and convert a conditional branch that can be evaluated statically
1286
- // into an unconditional branch.
1287
- MachineBasicBlock *Folded = nullptr ;
1288
- int64_t C0, C1;
1289
- if (isFromLoadImm (Cond[1 ], C0) && isFromLoadImm (Cond[2 ], C1)) {
1290
- switch (CC) {
1291
- case RISCVCC::COND_INVALID:
1292
- llvm_unreachable (" Unexpected CC" );
1293
- case RISCVCC::COND_EQ: {
1294
- Folded = (C0 == C1) ? TBB : FBB;
1295
- break ;
1296
- }
1297
- case RISCVCC::COND_NE: {
1298
- Folded = (C0 != C1) ? TBB : FBB;
1299
- break ;
1300
- }
1301
- case RISCVCC::COND_LT: {
1302
- Folded = (C0 < C1) ? TBB : FBB;
1303
- break ;
1304
- }
1305
- case RISCVCC::COND_GE: {
1306
- Folded = (C0 >= C1) ? TBB : FBB;
1307
- break ;
1308
- }
1309
- case RISCVCC::COND_LTU: {
1310
- Folded = ((uint64_t )C0 < (uint64_t )C1) ? TBB : FBB;
1311
- break ;
1312
- }
1313
- case RISCVCC::COND_GEU: {
1314
- Folded = ((uint64_t )C0 >= (uint64_t )C1) ? TBB : FBB;
1315
- break ;
1316
- }
1317
- }
1318
-
1319
- // Do the conversion
1320
- // Build the new unconditional branch
1321
- DebugLoc DL = MBB->findBranchDebugLoc ();
1322
- if (Folded) {
1323
- BuildMI (*MBB, MI, DL, get (RISCV::PseudoBR)).addMBB (Folded);
1324
- } else {
1325
- MachineFunction::iterator Fallthrough = ++MBB->getIterator ();
1326
- if (Fallthrough == MBB->getParent ()->end ())
1327
- return false ;
1328
- BuildMI (*MBB, MI, DL, get (RISCV::PseudoBR)).addMBB (&*Fallthrough);
1329
- }
1330
-
1331
- // Update successors of MBB.
1332
- if (Folded == TBB) {
1333
- // If we're taking TBB, then the succ to delete is the fallthrough (if
1334
- // it was a succ in the first place), or its the MBB from the
1335
- // unconditional branch.
1336
- if (!FBB) {
1337
- MachineFunction::iterator Fallthrough = ++MBB->getIterator ();
1338
- if (Fallthrough != MBB->getParent ()->end () &&
1339
- MBB->isSuccessor (&*Fallthrough))
1340
- MBB->removeSuccessor (&*Fallthrough, true );
1341
- } else {
1342
- MBB->removeSuccessor (FBB, true );
1343
- }
1344
- } else if (Folded == FBB) {
1345
- // If we're taking the fallthrough or unconditional branch, then the
1346
- // succ to remove is the one from the conditional branch.
1347
- MBB->removeSuccessor (TBB, true );
1348
- }
1349
-
1350
- MI.eraseFromParent ();
1351
- return true ;
1352
- }
1353
1359
1354
1360
// For two constants C0 and C1 from
1355
1361
// ```
@@ -1368,6 +1374,24 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
1368
1374
//
1369
1375
// To make sure this optimization is really beneficial, we only
1370
1376
// optimize for cases where Y had only one use (i.e. only used by the branch).
1377
+
1378
+ // Right now we only care about LI (i.e. ADDI x0, imm)
1379
+ auto isLoadImm = [](const MachineInstr *MI, int64_t &Imm) -> bool {
1380
+ if (MI->getOpcode () == RISCV::ADDI && MI->getOperand (1 ).isReg () &&
1381
+ MI->getOperand (1 ).getReg () == RISCV::X0) {
1382
+ Imm = MI->getOperand (2 ).getImm ();
1383
+ return true ;
1384
+ }
1385
+ return false ;
1386
+ };
1387
+ // Either a load from immediate instruction or X0.
1388
+ auto isFromLoadImm = [&](const MachineOperand &Op, int64_t &Imm) -> bool {
1389
+ if (!Op.isReg ())
1390
+ return false ;
1391
+ Register Reg = Op.getReg ();
1392
+ return Reg.isVirtual () && isLoadImm (MRI.getVRegDef (Reg), Imm);
1393
+ };
1394
+
1371
1395
MachineOperand &LHS = MI.getOperand (0 );
1372
1396
MachineOperand &RHS = MI.getOperand (1 );
1373
1397
// Try to find the register for constant Z; return
@@ -1386,6 +1410,7 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
1386
1410
};
1387
1411
1388
1412
bool Modify = false ;
1413
+ int64_t C0;
1389
1414
if (isFromLoadImm (LHS, C0) && MRI.hasOneUse (LHS.getReg ())) {
1390
1415
// Might be case 1.
1391
1416
// Signed integer overflow is UB. (UINT64_MAX is bigger so we don't need
0 commit comments