Skip to content

Commit 2595931

Browse files
authored
[IndVars] Support shl by constant and or disjoint in getExtendedOperandRecurrence. (#84282)
We can treat a shift by constant as a multiply by a power of 2 and we can treat an or disjoint as an 'add nsw nuw'. I've added a helper struct similar to a struct used in ScalarEvolution.cpp to represent the opcode, operands, and NSW/NUW flags for normal add/sub/mul and shl/or that are being treated as mul/add. I don't think we need to teach cloneIVUser about this. It will continue to clone them using cloneBitwiseIVUser. After the cloning we will ask for the SCEV expression for the cloned IV user and verify that it matches the AddRec returned by getExtendedOperandRecurrence. Since SCEV also knows how to convert shl to mul and or disjoint to add nsw nuw, this should usually match. If it doesn't match, the cloned IV user will be deleted.
1 parent ffa2810 commit 2595931

File tree

2 files changed

+137
-17
lines changed

2 files changed

+137
-17
lines changed

llvm/lib/Transforms/Utils/SimplifyIndVar.cpp

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,48 +1381,116 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
13811381
};
13821382
}
13831383

1384+
namespace {
1385+
1386+
// Represents an interesting integer binary operation for
1387+
// getExtendedOperandRecurrence. This may be a shl that is being treated as a
1388+
// multiply or an 'or disjoint' that is being treated as 'add nsw nuw'.
1389+
struct BinaryOp {
1390+
// Opcode the operation is analyzed as. Either the instruction's own opcode
// or one supplied explicitly by the caller of the second constructor.
unsigned Opcode;
1391+
// The two operands of the operation, in {LHS, RHS} order.
std::array<Value *, 2> Operands;
1392+
// No-signed-wrap flag; stays false unless a constructor sets it.
bool IsNSW = false;
1393+
// No-unsigned-wrap flag; stays false unless a constructor sets it.
bool IsNUW = false;
1394+
1395+
// Builds the description directly from an instruction: copies its opcode and
// its operands 0 and 1. The wrap flags are only read when the instruction is
// an OverflowingBinaryOperator; otherwise they keep their false defaults.
explicit BinaryOp(Instruction *Op)
1396+
: Opcode(Op->getOpcode()),
1397+
Operands({Op->getOperand(0), Op->getOperand(1)}) {
1398+
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
1399+
IsNSW = OBO->hasNoSignedWrap();
1400+
IsNUW = OBO->hasNoUnsignedWrap();
1401+
}
1402+
}
1403+
1404+
// Builds the description from explicit parts. Used when the instruction is
// being reinterpreted as a different opcode (e.g. shl as mul), so opcode,
// operands and flags cannot simply be copied from the instruction.
explicit BinaryOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
1405+
bool IsNSW = false, bool IsNUW = false)
1406+
: Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {}
1407+
};
1408+
1409+
} // end anonymous namespace
1410+
1411+
// Tries to describe Op as an add/sub/mul-like BinaryOp.
//
// Returns a BinaryOp for:
//   - add/sub/mul: described as-is (flags taken from the instruction);
//   - 'or disjoint': described as an Add with both NSW and NUW set;
//   - shl by a constant smaller than the bit width: described as a Mul by
//     the matching power of two.
// Returns std::nullopt for everything else: any other opcode, an 'or' without
// the disjoint flag, and a shl whose amount is non-constant or out of range.
static std::optional<BinaryOp> matchBinaryOp(Instruction *Op) {
1412+
switch (Op->getOpcode()) {
1413+
case Instruction::Add:
1414+
case Instruction::Sub:
1415+
case Instruction::Mul:
1416+
return BinaryOp(Op);
1417+
case Instruction::Or: {
1418+
// Convert or disjoint into add nuw nsw.
1419+
if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
1420+
return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
1421+
/*IsNSW=*/true, /*IsNUW=*/true);
1422+
// A plain 'or' is not supported: leave the switch and return nullopt.
break;
1423+
}
1424+
case Instruction::Shl: {
1425+
// Only a constant shift amount can be turned into a constant multiplier.
if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
1426+
// Bit width is read from the shift amount's type.
unsigned BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
1427+
1428+
// If the shift count is not less than the bitwidth, the result of
1429+
// the shift is undefined. Don't try to analyze it, because the
1430+
// resolution chosen here may differ from the resolution chosen in
1431+
// other parts of the compiler.
1432+
if (SA->getValue().ult(BitWidth)) {
1433+
// We can safely preserve the nuw flag in all cases. It's also safe to
1434+
// turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
1435+
// requires special handling. It can be preserved as long as we're not
1436+
// left shifting by bitwidth - 1.
1437+
bool IsNUW = Op->hasNoUnsignedWrap();
1438+
bool IsNSW = Op->hasNoSignedWrap() &&
1439+
(IsNUW || SA->getValue().ult(BitWidth - 1));
1440+
1441+
// X is the constant with only bit SA set, i.e. 2^SA: the multiplier
// that stands in for the shift.
ConstantInt *X =
1442+
ConstantInt::get(Op->getContext(),
1443+
APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
1444+
return BinaryOp(Instruction::Mul, Op->getOperand(0), X, IsNSW, IsNUW);
1445+
}
1446+
}
1447+
1448+
// Non-constant or out-of-range shift amount: not supported.
break;
1449+
}
1450+
}
1451+
1452+
// Reached for unsupported opcodes (the switch has no default case) and for
// the 'or'/'shl' cases that were rejected above.
return std::nullopt;
1453+
}
1454+
13841455
/// No-wrap operations can transfer sign extension of their result to their
13851456
/// operands. Generate the SCEV value for the widened operation without
13861457
/// actually modifying the IR yet. If the expression after extending the
13871458
/// operands is an AddRec for this loop, return the AddRec and the kind of
13881459
/// extension used.
13891460
WidenIV::WidenedRecTy
13901461
WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
1391-
// Handle the common case of add<nsw/nuw>
1392-
const unsigned OpCode = DU.NarrowUse->getOpcode();
1393-
// Only Add/Sub/Mul instructions supported yet.
1394-
if (OpCode != Instruction::Add && OpCode != Instruction::Sub &&
1395-
OpCode != Instruction::Mul)
1462+
auto Op = matchBinaryOp(DU.NarrowUse);
1463+
if (!Op)
13961464
return {nullptr, ExtendKind::Unknown};
13971465

1466+
assert((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
1467+
Op->Opcode == Instruction::Mul) &&
1468+
"Unexpected opcode");
1469+
13981470
// One operand (NarrowDef) has already been extended to WideDef. Now determine
13991471
// if extending the other will lead to a recurrence.
1400-
const unsigned ExtendOperIdx =
1401-
DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0;
1402-
assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU");
1472+
const unsigned ExtendOperIdx = Op->Operands[0] == DU.NarrowDef ? 1 : 0;
1473+
assert(Op->Operands[1 - ExtendOperIdx] == DU.NarrowDef && "bad DU");
14031474

1404-
const OverflowingBinaryOperator *OBO =
1405-
cast<OverflowingBinaryOperator>(DU.NarrowUse);
14061475
ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
1407-
if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) &&
1408-
!(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) {
1476+
if (!(ExtKind == ExtendKind::Sign && Op->IsNSW) &&
1477+
!(ExtKind == ExtendKind::Zero && Op->IsNUW)) {
14091478
ExtKind = ExtendKind::Unknown;
14101479

14111480
// For a non-negative NarrowDef, we can choose either type of
14121481
// extension. We want to use the current extend kind if legal
14131482
// (see above), and we only hit this code if we need to check
14141483
// the opposite case.
14151484
if (DU.NeverNegative) {
1416-
if (OBO->hasNoSignedWrap()) {
1485+
if (Op->IsNSW) {
14171486
ExtKind = ExtendKind::Sign;
1418-
} else if (OBO->hasNoUnsignedWrap()) {
1487+
} else if (Op->IsNUW) {
14191488
ExtKind = ExtendKind::Zero;
14201489
}
14211490
}
14221491
}
14231492

1424-
const SCEV *ExtendOperExpr =
1425-
SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx));
1493+
const SCEV *ExtendOperExpr = SE->getSCEV(Op->Operands[ExtendOperIdx]);
14261494
if (ExtKind == ExtendKind::Sign)
14271495
ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType);
14281496
else if (ExtKind == ExtendKind::Zero)
@@ -1443,7 +1511,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
14431511
if (ExtendOperIdx == 0)
14441512
std::swap(lhs, rhs);
14451513
const SCEVAddRecExpr *AddRec =
1446-
dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode));
1514+
dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, Op->Opcode));
14471515

14481516
if (!AddRec || AddRec->getLoop() != L)
14491517
return {nullptr, ExtendKind::Unknown};

llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,3 +493,55 @@ for.body: ; preds = %for.body.lr.ph, %fo
493493
%cmp = icmp ult i32 %add, %length
494494
br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
495495
}
496+
497+
; Test that we can handle shl and disjoint or in getExtendedOperandRecurrence.
498+
define void @foo7(i32 %n, ptr %a, i32 %x) {
499+
; CHECK-LABEL: @foo7(
500+
; CHECK-NEXT: entry:
501+
; CHECK-NEXT: [[CMP6:%.*]] = icmp sgt i32 [[N:%.*]], 0
502+
; CHECK-NEXT: br i1 [[CMP6]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_COND_CLEANUP:%.*]]
503+
; CHECK: for.body.lr.ph:
504+
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i32 [[X:%.*]], 2
505+
; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[ADD1]] to i64
506+
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[N]] to i64
507+
; CHECK-NEXT: br label [[FOR_BODY:%.*]]
508+
; CHECK: for.cond.cleanup.loopexit:
509+
; CHECK-NEXT: br label [[FOR_COND_CLEANUP]]
510+
; CHECK: for.cond.cleanup:
511+
; CHECK-NEXT: ret void
512+
; CHECK: for.body:
513+
; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[FOR_BODY_LR_PH]] ]
514+
; CHECK-NEXT: [[TMP2:%.*]] = shl nsw i64 [[INDVARS_IV]], 1
515+
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i64 [[TMP2]], 1
516+
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP3]]
517+
; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[INDVARS_IV]] to i32
518+
; CHECK-NEXT: store i32 [[TMP4]], ptr [[ARRAYIDX]], align 4
519+
; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nsw i64 [[INDVARS_IV]], [[TMP0]]
520+
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV_NEXT]], [[TMP1]]
521+
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]]
522+
;
523+
entry:
524+
; Skip the loop entirely unless %n > 0.
%cmp6 = icmp sgt i32 %n, 0
525+
br i1 %cmp6, label %for.body.lr.ph, label %for.cond.cleanup
526+
527+
for.body.lr.ph: ; preds = %entry
528+
; Loop-invariant step of the induction variable: %x + 2.
%add1 = add nsw i32 %x, 2
529+
br label %for.body
530+
531+
for.cond.cleanup.loopexit: ; preds = %for.body
532+
br label %for.cond.cleanup
533+
534+
for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry
535+
ret void
536+
537+
for.body: ; preds = %for.body.lr.ph, %for.body
538+
; Narrow (i32) induction variable: starts at 0 and advances by %add1.
%i.07 = phi i32 [ 0, %for.body.lr.ph ], [ %add2, %for.body ]
539+
; shl by a constant: the case that is analyzed as a multiply by 2.
%mul = shl nsw i32 %i.07, 1
540+
; or disjoint: the case that is analyzed as an add nsw nuw.
%add = or disjoint i32 %mul, 1
541+
; This sext is what the CHECK lines expect to disappear: after widening, the
; shl and or are done in i64 and feed the getelementptr directly.
%idxprom = sext i32 %add to i64
542+
%arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
543+
store i32 %i.07, ptr %arrayidx, align 4
544+
%add2 = add nsw i32 %add1, %i.07
545+
%cmp = icmp slt i32 %add2, %n
546+
br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
547+
}

0 commit comments

Comments
 (0)