@@ -1381,48 +1381,116 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
1381
1381
};
1382
1382
}
1383
1383
1384
+ namespace {
1385
+
1386
+ // Represents a interesting integer binary operation for
1387
+ // getExtendedOperandRecurrence. This may be a shl that is being treated as a
1388
+ // multiply or a 'or disjoint' that is being treated as 'add nsw nuw'.
1389
+ struct BinaryOp {
1390
+ unsigned Opcode;
1391
+ std::array<Value *, 2 > Operands;
1392
+ bool IsNSW = false ;
1393
+ bool IsNUW = false ;
1394
+
1395
+ explicit BinaryOp (Instruction *Op)
1396
+ : Opcode(Op->getOpcode ()),
1397
+ Operands({Op->getOperand (0 ), Op->getOperand (1 )}) {
1398
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
1399
+ IsNSW = OBO->hasNoSignedWrap ();
1400
+ IsNUW = OBO->hasNoUnsignedWrap ();
1401
+ }
1402
+ }
1403
+
1404
+ explicit BinaryOp (Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
1405
+ bool IsNSW = false , bool IsNUW = false )
1406
+ : Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {}
1407
+ };
1408
+
1409
+ } // end anonymous namespace
1410
+
1411
+ static std::optional<BinaryOp> matchBinaryOp (Instruction *Op) {
1412
+ switch (Op->getOpcode ()) {
1413
+ case Instruction::Add:
1414
+ case Instruction::Sub:
1415
+ case Instruction::Mul:
1416
+ return BinaryOp (Op);
1417
+ case Instruction::Or: {
1418
+ // Convert or disjoint into add nuw nsw.
1419
+ if (cast<PossiblyDisjointInst>(Op)->isDisjoint ())
1420
+ return BinaryOp (Instruction::Add, Op->getOperand (0 ), Op->getOperand (1 ),
1421
+ /* IsNSW=*/ true , /* IsNUW=*/ true );
1422
+ break ;
1423
+ }
1424
+ case Instruction::Shl: {
1425
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand (1 ))) {
1426
+ unsigned BitWidth = cast<IntegerType>(SA->getType ())->getBitWidth ();
1427
+
1428
+ // If the shift count is not less than the bitwidth, the result of
1429
+ // the shift is undefined. Don't try to analyze it, because the
1430
+ // resolution chosen here may differ from the resolution chosen in
1431
+ // other parts of the compiler.
1432
+ if (SA->getValue ().ult (BitWidth)) {
1433
+ // We can safely preserve the nuw flag in all cases. It's also safe to
1434
+ // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
1435
+ // requires special handling. It can be preserved as long as we're not
1436
+ // left shifting by bitwidth - 1.
1437
+ bool IsNUW = Op->hasNoUnsignedWrap ();
1438
+ bool IsNSW = Op->hasNoSignedWrap () &&
1439
+ (IsNUW || SA->getValue ().ult (BitWidth - 1 ));
1440
+
1441
+ ConstantInt *X =
1442
+ ConstantInt::get (Op->getContext (),
1443
+ APInt::getOneBitSet (BitWidth, SA->getZExtValue ()));
1444
+ return BinaryOp (Instruction::Mul, Op->getOperand (0 ), X, IsNSW, IsNUW);
1445
+ }
1446
+ }
1447
+
1448
+ break ;
1449
+ }
1450
+ }
1451
+
1452
+ return std::nullopt;
1453
+ }
1454
+
1384
1455
// / No-wrap operations can transfer sign extension of their result to their
1385
1456
// / operands. Generate the SCEV value for the widened operation without
1386
1457
// / actually modifying the IR yet. If the expression after extending the
1387
1458
// / operands is an AddRec for this loop, return the AddRec and the kind of
1388
1459
// / extension used.
1389
1460
WidenIV::WidenedRecTy
1390
1461
WidenIV::getExtendedOperandRecurrence (WidenIV::NarrowIVDefUse DU) {
1391
- // Handle the common case of add<nsw/nuw>
1392
- const unsigned OpCode = DU.NarrowUse ->getOpcode ();
1393
- // Only Add/Sub/Mul instructions supported yet.
1394
- if (OpCode != Instruction::Add && OpCode != Instruction::Sub &&
1395
- OpCode != Instruction::Mul)
1462
+ auto Op = matchBinaryOp (DU.NarrowUse );
1463
+ if (!Op)
1396
1464
return {nullptr , ExtendKind::Unknown};
1397
1465
1466
+ assert ((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
1467
+ Op->Opcode == Instruction::Mul) &&
1468
+ " Unexpected opcode" );
1469
+
1398
1470
// One operand (NarrowDef) has already been extended to WideDef. Now determine
1399
1471
// if extending the other will lead to a recurrence.
1400
- const unsigned ExtendOperIdx =
1401
- DU.NarrowUse ->getOperand (0 ) == DU.NarrowDef ? 1 : 0 ;
1402
- assert (DU.NarrowUse ->getOperand (1 -ExtendOperIdx) == DU.NarrowDef && " bad DU" );
1472
+ const unsigned ExtendOperIdx = Op->Operands [0 ] == DU.NarrowDef ? 1 : 0 ;
1473
+ assert (Op->Operands [1 - ExtendOperIdx] == DU.NarrowDef && " bad DU" );
1403
1474
1404
- const OverflowingBinaryOperator *OBO =
1405
- cast<OverflowingBinaryOperator>(DU.NarrowUse );
1406
1475
ExtendKind ExtKind = getExtendKind (DU.NarrowDef );
1407
- if (!(ExtKind == ExtendKind::Sign && OBO-> hasNoSignedWrap () ) &&
1408
- !(ExtKind == ExtendKind::Zero && OBO-> hasNoUnsignedWrap () )) {
1476
+ if (!(ExtKind == ExtendKind::Sign && Op-> IsNSW ) &&
1477
+ !(ExtKind == ExtendKind::Zero && Op-> IsNUW )) {
1409
1478
ExtKind = ExtendKind::Unknown;
1410
1479
1411
1480
// For a non-negative NarrowDef, we can choose either type of
1412
1481
// extension. We want to use the current extend kind if legal
1413
1482
// (see above), and we only hit this code if we need to check
1414
1483
// the opposite case.
1415
1484
if (DU.NeverNegative ) {
1416
- if (OBO-> hasNoSignedWrap () ) {
1485
+ if (Op-> IsNSW ) {
1417
1486
ExtKind = ExtendKind::Sign;
1418
- } else if (OBO-> hasNoUnsignedWrap () ) {
1487
+ } else if (Op-> IsNUW ) {
1419
1488
ExtKind = ExtendKind::Zero;
1420
1489
}
1421
1490
}
1422
1491
}
1423
1492
1424
- const SCEV *ExtendOperExpr =
1425
- SE->getSCEV (DU.NarrowUse ->getOperand (ExtendOperIdx));
1493
+ const SCEV *ExtendOperExpr = SE->getSCEV (Op->Operands [ExtendOperIdx]);
1426
1494
if (ExtKind == ExtendKind::Sign)
1427
1495
ExtendOperExpr = SE->getSignExtendExpr (ExtendOperExpr, WideType);
1428
1496
else if (ExtKind == ExtendKind::Zero)
@@ -1443,7 +1511,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
1443
1511
if (ExtendOperIdx == 0 )
1444
1512
std::swap (lhs, rhs);
1445
1513
const SCEVAddRecExpr *AddRec =
1446
- dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode (lhs, rhs, OpCode ));
1514
+ dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode (lhs, rhs, Op-> Opcode ));
1447
1515
1448
1516
if (!AddRec || AddRec->getLoop () != L)
1449
1517
return {nullptr , ExtendKind::Unknown};
0 commit comments