[IndVars] Support shl by constant and or disjoint in getExtendedOperandRecurrence. #84282

Merged (3 commits) on Mar 14, 2024

102 changes: 85 additions & 17 deletions llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -1381,48 +1381,116 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
};
}

namespace {

// Represents an interesting integer binary operation for
// getExtendedOperandRecurrence. This may be a shl that is being treated as a
// multiply or an 'or disjoint' that is being treated as an 'add nsw nuw'.
struct BinaryOp {
unsigned Opcode;
std::array<Value *, 2> Operands;
bool IsNSW = false;
bool IsNUW = false;

explicit BinaryOp(Instruction *Op)
: Opcode(Op->getOpcode()),
Operands({Op->getOperand(0), Op->getOperand(1)}) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
IsNSW = OBO->hasNoSignedWrap();
IsNUW = OBO->hasNoUnsignedWrap();
}
}

explicit BinaryOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
bool IsNSW = false, bool IsNUW = false)
: Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {}
};

} // end anonymous namespace

static std::optional<BinaryOp> matchBinaryOp(Instruction *Op) {
switch (Op->getOpcode()) {
case Instruction::Add:
case Instruction::Sub:
case Instruction::Mul:
return BinaryOp(Op);
case Instruction::Or: {
// Convert or disjoint into add nuw nsw.
if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
/*IsNSW=*/true, /*IsNUW=*/true);
break;
}
case Instruction::Shl: {
if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
unsigned BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();

// If the shift count is not less than the bitwidth, the result of
// the shift is undefined. Don't try to analyze it, because the
// resolution chosen here may differ from the resolution chosen in
// other parts of the compiler.
if (SA->getValue().ult(BitWidth)) {
// We can safely preserve the nuw flag in all cases. It's also safe to
// turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
// requires special handling. It can be preserved as long as we're not
// left shifting by bitwidth - 1.
bool IsNUW = Op->hasNoUnsignedWrap();
bool IsNSW = Op->hasNoSignedWrap() &&
(IsNUW || SA->getValue().ult(BitWidth - 1));

ConstantInt *X =
ConstantInt::get(Op->getContext(),
APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
return BinaryOp(Instruction::Mul, Op->getOperand(0), X, IsNSW, IsNUW);
}
}

break;
}
}

return std::nullopt;
}

/// No-wrap operations can transfer sign extension of their result to their
/// operands. Generate the SCEV value for the widened operation without
/// actually modifying the IR yet. If the expression after extending the
/// operands is an AddRec for this loop, return the AddRec and the kind of
/// extension used.
WidenIV::WidenedRecTy
WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
- // Handle the common case of add<nsw/nuw>
- const unsigned OpCode = DU.NarrowUse->getOpcode();
- // Only Add/Sub/Mul instructions supported yet.
- if (OpCode != Instruction::Add && OpCode != Instruction::Sub &&
-     OpCode != Instruction::Mul)
+ auto Op = matchBinaryOp(DU.NarrowUse);
+ if (!Op)
  return {nullptr, ExtendKind::Unknown};

+ assert((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
+         Op->Opcode == Instruction::Mul) &&
+        "Unexpected opcode");

// One operand (NarrowDef) has already been extended to WideDef. Now determine
// if extending the other will lead to a recurrence.
- const unsigned ExtendOperIdx =
-     DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0;
- assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU");
+ const unsigned ExtendOperIdx = Op->Operands[0] == DU.NarrowDef ? 1 : 0;
+ assert(Op->Operands[1 - ExtendOperIdx] == DU.NarrowDef && "bad DU");

- const OverflowingBinaryOperator *OBO =
-     cast<OverflowingBinaryOperator>(DU.NarrowUse);
  ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
- if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) &&
-     !(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) {
+ if (!(ExtKind == ExtendKind::Sign && Op->IsNSW) &&
+     !(ExtKind == ExtendKind::Zero && Op->IsNUW)) {
ExtKind = ExtendKind::Unknown;

// For a non-negative NarrowDef, we can choose either type of
// extension. We want to use the current extend kind if legal
// (see above), and we only hit this code if we need to check
// the opposite case.
if (DU.NeverNegative) {
- if (OBO->hasNoSignedWrap()) {
+ if (Op->IsNSW) {
  ExtKind = ExtendKind::Sign;
- } else if (OBO->hasNoUnsignedWrap()) {
+ } else if (Op->IsNUW) {
ExtKind = ExtendKind::Zero;
}
}
}

- const SCEV *ExtendOperExpr =
-     SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx));
+ const SCEV *ExtendOperExpr = SE->getSCEV(Op->Operands[ExtendOperIdx]);
if (ExtKind == ExtendKind::Sign)
ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType);
else if (ExtKind == ExtendKind::Zero)
@@ -1443,7 +1511,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
if (ExtendOperIdx == 0)
std::swap(lhs, rhs);
const SCEVAddRecExpr *AddRec =
-     dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode));
+     dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, Op->Opcode));

if (!AddRec || AddRec->getLoop() != L)
return {nullptr, ExtendKind::Unknown};
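
The flag handling in matchBinaryOp is the subtle part of the change. As a minimal standalone sketch of the "shift by bitwidth - 1" caveat (illustration only, not part of this PR; the function name, the i8 width, and the sample value are chosen just for the example):

; For %x = -1: -1 << 7 == -128 is representable, and the bits shifted out
; agree with the result's sign bit, so this shl may legally carry nsw.
define i8 @shl_nsw_bitwidth_minus_one(i8 %x) {
  %s = shl nsw i8 %x, 7
  ; The multiply form uses the constant 1 << 7, which is -128 in i8. For
  ; %x = -1 the product would be +128, which overflows i8, so this mul must
  ; not be marked nsw; per the code above it is only safe when nuw also holds.
  %m = mul i8 %x, -128
  %r = add i8 %s, %m
  ret i8 %r
}
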
52 changes: 52 additions & 0 deletions llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
@@ -493,3 +493,55 @@ for.body: ; preds = %for.body.lr.ph, %fo
%cmp = icmp ult i32 %add, %length
br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
}

; Test that we can handle shl and disjoint or in getExtendedOperandRecurrence.
define void @foo7(i32 %n, ptr %a, i32 %x) {
; CHECK-LABEL: @foo7(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP6:%.*]] = icmp sgt i32 [[N:%.*]], 0
; CHECK-NEXT: br i1 [[CMP6]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_COND_CLEANUP:%.*]]
; CHECK: for.body.lr.ph:
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i32 [[X:%.*]], 2
; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[ADD1]] to i64
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[N]] to i64
; CHECK-NEXT: br label [[FOR_BODY:%.*]]
; CHECK: for.cond.cleanup.loopexit:
; CHECK-NEXT: br label [[FOR_COND_CLEANUP]]
; CHECK: for.cond.cleanup:
; CHECK-NEXT: ret void
; CHECK: for.body:
; CHECK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[FOR_BODY_LR_PH]] ]
; CHECK-NEXT: [[TMP2:%.*]] = shl nsw i64 [[INDVARS_IV]], 1
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i64 [[TMP2]], 1
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP3]]
; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[INDVARS_IV]] to i32
; CHECK-NEXT: store i32 [[TMP4]], ptr [[ARRAYIDX]], align 4
; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nsw i64 [[INDVARS_IV]], [[TMP0]]
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV_NEXT]], [[TMP1]]
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]]
;
entry:
%cmp6 = icmp sgt i32 %n, 0
br i1 %cmp6, label %for.body.lr.ph, label %for.cond.cleanup

for.body.lr.ph: ; preds = %entry
%add1 = add nsw i32 %x, 2
br label %for.body

for.cond.cleanup.loopexit: ; preds = %for.body
br label %for.cond.cleanup

for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry
ret void

for.body: ; preds = %for.body.lr.ph, %for.body
%i.07 = phi i32 [ 0, %for.body.lr.ph ], [ %add2, %for.body ]
%mul = shl nsw i32 %i.07, 1
%add = or disjoint i32 %mul, 1
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
store i32 %i.07, ptr %arrayidx, align 4
%add2 = add nsw i32 %add1, %i.07
%cmp = icmp slt i32 %add2, %n
br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
}
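
The test above exercises the new path: the low bit of %mul is known zero, so the 'or disjoint' adds no carries and matchBinaryOp may treat it as an 'add nuw nsw', letting getExtendedOperandRecurrence form a widened AddRec; in the CHECK lines the index is then computed directly in i64 with no sext inside the loop. A minimal sketch of that equivalence in isolation (illustration only; the function name is invented):

define i32 @or_disjoint_equals_add(i32 %i) {
  %shl = shl nsw i32 %i, 1         ; low bit of %shl is always zero
  %or  = or disjoint i32 %shl, 1   ; operands share no set bits
  %add = add nuw nsw i32 %shl, 1   ; same value as %or; no bit can carry, so no wrap
  %dif = sub i32 %or, %add         ; always 0
  ret i32 %dif
}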