Skip to content

Commit 287294d

Browse files
authored
[ConstraintElim] Do not allow overflows in Decomposition (#140541)
Consider the following case: ``` define i1 @pr140481(i32 %x) { %cond = icmp slt i32 %x, 0 call void @llvm.assume(i1 %cond) %add = add nsw i32 %x, 5001000 %mul1 = mul nsw i32 %add, -5001000 %mul2 = mul nsw i32 %mul1, 5001000 %cmp2 = icmp sgt i32 %mul2, 0 ret i1 %cmp2 } ``` Before this patch, `decompose(%mul2)` returns `-25010001000000 * %x + 4052193514966861312`. Therefore, `%cmp2` will be simplified into true because `%x s< 0 && -25010001000000 * %x + 4052193514966861312 s<= 0` is unsat. It is incorrect since the offset `-25010001000000 * 5001000 -> 4052193514966861312` signed wraps. This patch treats a decomposition as invalid if overflows occur when computing coefficients. Closes #140481.
1 parent f72a8ee commit 287294d

File tree

2 files changed

+95
-48
lines changed

2 files changed

+95
-48
lines changed

llvm/lib/Transforms/Scalar/ConstraintElimination.cpp

Lines changed: 73 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,6 @@ static cl::opt<bool> DumpReproducers(
6464
static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max();
6565
static int64_t MinSignedConstraintValue = std::numeric_limits<int64_t>::min();
6666

67-
// A helper to multiply 2 signed integers where overflowing is allowed.
68-
static int64_t multiplyWithOverflow(int64_t A, int64_t B) {
69-
int64_t Result;
70-
MulOverflow(A, B, Result);
71-
return Result;
72-
}
73-
74-
// A helper to add 2 signed integers where overflowing is allowed.
75-
static int64_t addWithOverflow(int64_t A, int64_t B) {
76-
int64_t Result;
77-
AddOverflow(A, B, Result);
78-
return Result;
79-
}
80-
8167
static Instruction *getContextInstForUse(Use &U) {
8268
Instruction *UserI = cast<Instruction>(U.getUser());
8369
if (auto *Phi = dyn_cast<PHINode>(UserI))
@@ -366,26 +352,42 @@ struct Decomposition {
366352
Decomposition(int64_t Offset, ArrayRef<DecompEntry> Vars)
367353
: Offset(Offset), Vars(Vars) {}
368354

369-
void add(int64_t OtherOffset) {
370-
Offset = addWithOverflow(Offset, OtherOffset);
355+
/// Add \p OtherOffset and return true if the operation overflows, i.e. the
356+
/// new decomposition is invalid.
357+
[[nodiscard]] bool add(int64_t OtherOffset) {
358+
return AddOverflow(Offset, OtherOffset, Offset);
371359
}
372360

373-
void add(const Decomposition &Other) {
374-
add(Other.Offset);
361+
/// Add \p Other and return true if the operation overflows, i.e. the new
362+
/// decomposition is invalid.
363+
[[nodiscard]] bool add(const Decomposition &Other) {
364+
if (add(Other.Offset))
365+
return true;
375366
append_range(Vars, Other.Vars);
367+
return false;
376368
}
377369

378-
void sub(const Decomposition &Other) {
370+
/// Subtract \p Other and return true if the operation overflows, i.e. the new
371+
/// decomposition is invalid.
372+
[[nodiscard]] bool sub(const Decomposition &Other) {
379373
Decomposition Tmp = Other;
380-
Tmp.mul(-1);
381-
add(Tmp.Offset);
374+
if (Tmp.mul(-1))
375+
return true;
376+
if (add(Tmp.Offset))
377+
return true;
382378
append_range(Vars, Tmp.Vars);
379+
return false;
383380
}
384381

385-
void mul(int64_t Factor) {
386-
Offset = multiplyWithOverflow(Offset, Factor);
382+
/// Multiply all coefficients by \p Factor and return true if the operation
383+
/// overflows, i.e. the new decomposition is invalid.
384+
[[nodiscard]] bool mul(int64_t Factor) {
385+
if (MulOverflow(Offset, Factor, Offset))
386+
return true;
387387
for (auto &Var : Vars)
388-
Var.Coefficient = multiplyWithOverflow(Var.Coefficient, Factor);
388+
if (MulOverflow(Var.Coefficient, Factor, Var.Coefficient))
389+
return true;
390+
return false;
389391
}
390392
};
391393

@@ -467,8 +469,10 @@ static Decomposition decomposeGEP(GEPOperator &GEP,
467469
Decomposition Result(ConstantOffset.getSExtValue(), DecompEntry(1, BasePtr));
468470
for (auto [Index, Scale] : VariableOffsets) {
469471
auto IdxResult = decompose(Index, Preconditions, IsSigned, DL);
470-
IdxResult.mul(Scale.getSExtValue());
471-
Result.add(IdxResult);
472+
if (IdxResult.mul(Scale.getSExtValue()))
473+
return &GEP;
474+
if (Result.add(IdxResult))
475+
return &GEP;
472476

473477
if (!NW.hasNoUnsignedWrap()) {
474478
// Try to prove nuw from nusw and nneg.
@@ -488,11 +492,13 @@ static Decomposition decompose(Value *V,
488492
SmallVectorImpl<ConditionTy> &Preconditions,
489493
bool IsSigned, const DataLayout &DL) {
490494

491-
auto MergeResults = [&Preconditions, IsSigned, &DL](Value *A, Value *B,
492-
bool IsSignedB) {
495+
auto MergeResults = [&Preconditions, IsSigned,
496+
&DL](Value *A, Value *B,
497+
bool IsSignedB) -> std::optional<Decomposition> {
493498
auto ResA = decompose(A, Preconditions, IsSigned, DL);
494499
auto ResB = decompose(B, Preconditions, IsSignedB, DL);
495-
ResA.add(ResB);
500+
if (ResA.add(ResB))
501+
return std::nullopt;
496502
return ResA;
497503
};
498504

@@ -533,21 +539,26 @@ static Decomposition decompose(Value *V,
533539
V = Op0;
534540
}
535541

536-
if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1))))
537-
return MergeResults(Op0, Op1, IsSigned);
542+
if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) {
543+
if (auto Decomp = MergeResults(Op0, Op1, IsSigned))
544+
return *Decomp;
545+
return {V, IsKnownNonNegative};
546+
}
538547

539548
if (match(V, m_NSWSub(m_Value(Op0), m_Value(Op1)))) {
540549
auto ResA = decompose(Op0, Preconditions, IsSigned, DL);
541550
auto ResB = decompose(Op1, Preconditions, IsSigned, DL);
542-
ResA.sub(ResB);
543-
return ResA;
551+
if (!ResA.sub(ResB))
552+
return ResA;
553+
return {V, IsKnownNonNegative};
544554
}
545555

546556
ConstantInt *CI;
547557
if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI))) && canUseSExt(CI)) {
548558
auto Result = decompose(Op0, Preconditions, IsSigned, DL);
549-
Result.mul(CI->getSExtValue());
550-
return Result;
559+
if (!Result.mul(CI->getSExtValue()))
560+
return Result;
561+
return {V, IsKnownNonNegative};
551562
}
552563

553564
// (shl nsw x, shift) is (mul nsw x, (1<<shift)), with the exception of
@@ -557,8 +568,9 @@ static Decomposition decompose(Value *V,
557568
if (Shift < Ty->getIntegerBitWidth() - 1) {
558569
assert(Shift < 64 && "Would overflow");
559570
auto Result = decompose(Op0, Preconditions, IsSigned, DL);
560-
Result.mul(int64_t(1) << Shift);
561-
return Result;
571+
if (!Result.mul(int64_t(1) << Shift))
572+
return Result;
573+
return {V, IsKnownNonNegative};
562574
}
563575
}
564576

@@ -593,8 +605,11 @@ static Decomposition decompose(Value *V,
593605
Value *Op1;
594606
ConstantInt *CI;
595607
if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) {
596-
return MergeResults(Op0, Op1, IsSigned);
608+
if (auto Decomp = MergeResults(Op0, Op1, IsSigned))
609+
return *Decomp;
610+
return {V, IsKnownNonNegative};
597611
}
612+
598613
if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) {
599614
if (!isKnownNonNegative(Op0, DL))
600615
Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0,
@@ -603,41 +618,51 @@ static Decomposition decompose(Value *V,
603618
Preconditions.emplace_back(CmpInst::ICMP_SGE, Op1,
604619
ConstantInt::get(Op1->getType(), 0));
605620

606-
return MergeResults(Op0, Op1, IsSigned);
621+
if (auto Decomp = MergeResults(Op0, Op1, IsSigned))
622+
return *Decomp;
623+
return {V, IsKnownNonNegative};
607624
}
608625

609626
if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() &&
610627
canUseSExt(CI)) {
611628
Preconditions.emplace_back(
612629
CmpInst::ICMP_UGE, Op0,
613630
ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1));
614-
return MergeResults(Op0, CI, true);
631+
if (auto Decomp = MergeResults(Op0, CI, true))
632+
return *Decomp;
633+
return {V, IsKnownNonNegative};
615634
}
616635

617636
// Decompose or as an add if there are no common bits between the operands.
618-
if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI))))
619-
return MergeResults(Op0, CI, IsSigned);
637+
if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) {
638+
if (auto Decomp = MergeResults(Op0, CI, IsSigned))
639+
return *Decomp;
640+
return {V, IsKnownNonNegative};
641+
}
620642

621643
if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) {
622644
if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64)
623645
return {V, IsKnownNonNegative};
624646
auto Result = decompose(Op1, Preconditions, IsSigned, DL);
625-
Result.mul(int64_t{1} << CI->getSExtValue());
626-
return Result;
647+
if (!Result.mul(int64_t{1} << CI->getSExtValue()))
648+
return Result;
649+
return {V, IsKnownNonNegative};
627650
}
628651

629652
if (match(V, m_NUWMul(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI) &&
630653
(!CI->isNegative())) {
631654
auto Result = decompose(Op1, Preconditions, IsSigned, DL);
632-
Result.mul(CI->getSExtValue());
633-
return Result;
655+
if (!Result.mul(CI->getSExtValue()))
656+
return Result;
657+
return {V, IsKnownNonNegative};
634658
}
635659

636660
if (match(V, m_NUWSub(m_Value(Op0), m_Value(Op1)))) {
637661
auto ResA = decompose(Op0, Preconditions, IsSigned, DL);
638662
auto ResB = decompose(Op1, Preconditions, IsSigned, DL);
639-
ResA.sub(ResB);
640-
return ResA;
663+
if (!ResA.sub(ResB))
664+
return ResA;
665+
return {V, IsKnownNonNegative};
641666
}
642667

643668
return {V, IsKnownNonNegative};

llvm/test/Transforms/ConstraintElimination/constraint-overflow.ll

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,25 @@ entry:
5252
%c = icmp slt i64 0, %sub
5353
ret i1 %c
5454
}
55+
56+
define i1 @pr140481(i32 %x) {
57+
; CHECK-LABEL: define i1 @pr140481(
58+
; CHECK-SAME: i32 [[X:%.*]]) {
59+
; CHECK-NEXT: entry:
60+
; CHECK-NEXT: [[COND:%.*]] = icmp slt i32 [[X]], 0
61+
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
62+
; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[X]], 5001000
63+
; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i32 [[ADD]], -5001000
64+
; CHECK-NEXT: [[MUL2:%.*]] = mul nsw i32 [[MUL1]], 5001000
65+
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[MUL2]], 0
66+
; CHECK-NEXT: ret i1 [[CMP2]]
67+
;
68+
entry:
69+
%cond = icmp slt i32 %x, 0
70+
call void @llvm.assume(i1 %cond)
71+
%add = add nsw i32 %x, 5001000
72+
%mul1 = mul nsw i32 %add, -5001000
73+
%mul2 = mul nsw i32 %mul1, 5001000
74+
%cmp2 = icmp sgt i32 %mul2, 0
75+
ret i1 %cmp2
76+
}

0 commit comments

Comments
 (0)