Skip to content

Commit 312cb34

Browse files
authored
[Reassociate] Preserve NUW flags after expr tree rewriting (#72360)
Alive2: https://alive2.llvm.org/ce/z/38KiC_
1 parent 57eb205 commit 312cb34

File tree

4 files changed

+78
-30
lines changed

4 files changed

+78
-30
lines changed

llvm/include/llvm/Transforms/Scalar/Reassociate.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,8 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
102102
void canonicalizeOperands(Instruction *I);
103103
void ReassociateExpression(BinaryOperator *I);
104104
void RewriteExprTree(BinaryOperator *I,
105-
SmallVectorImpl<reassociate::ValueEntry> &Ops);
105+
SmallVectorImpl<reassociate::ValueEntry> &Ops,
106+
bool HasNUW);
106107
Value *OptimizeExpression(BinaryOperator *I,
107108
SmallVectorImpl<reassociate::ValueEntry> &Ops);
108109
Value *OptimizeAdd(Instruction *I,

llvm/lib/Transforms/Scalar/Reassociate.cpp

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -466,7 +466,8 @@ using RepeatedValue = std::pair<Value*, APInt>;
466466
/// type and thus make the expression bigger.
467467
static bool LinearizeExprTree(Instruction *I,
468468
SmallVectorImpl<RepeatedValue> &Ops,
469-
ReassociatePass::OrderedSet &ToRedo) {
469+
ReassociatePass::OrderedSet &ToRedo,
470+
bool &HasNUW) {
470471
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
471472
"Expected a UnaryOperator or BinaryOperator!");
472473
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I,
515516
std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
516517
I = P.first; // We examine the operands of this binary operator.
517518

519+
if (isa<OverflowingBinaryOperator>(I))
520+
HasNUW &= I->hasNoUnsignedWrap();
521+
518522
for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
519523
Value *Op = I->getOperand(OpIdx);
520524
APInt Weight = P.second; // Number of paths to this operand.
@@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I,
657661
/// Now that the operands for this expression tree are
658662
/// linearized and optimized, emit them in-order.
659663
void ReassociatePass::RewriteExprTree(BinaryOperator *I,
660-
SmallVectorImpl<ValueEntry> &Ops) {
664+
SmallVectorImpl<ValueEntry> &Ops,
665+
bool HasNUW) {
661666
assert(Ops.size() > 1 && "Single values should be used directly!");
662667

663668
// Since our optimizations should never increase the number of operations, the
@@ -814,14 +819,20 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
814819
if (ExpressionChangedStart) {
815820
bool ClearFlags = true;
816821
do {
817-
// Preserve FastMathFlags.
822+
// Preserve flags.
818823
if (ClearFlags) {
819824
if (isa<FPMathOperator>(I)) {
820825
FastMathFlags Flags = I->getFastMathFlags();
821826
ExpressionChangedStart->clearSubclassOptionalData();
822827
ExpressionChangedStart->setFastMathFlags(Flags);
823-
} else
828+
} else {
824829
ExpressionChangedStart->clearSubclassOptionalData();
830+
// Note that it doesn't hold for mul if one of the operands is zero.
831+
// TODO: We can preserve NUW flag if we prove that all mul operands
832+
// are non-zero.
833+
if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
834+
ExpressionChangedStart->setHasNoUnsignedWrap();
835+
}
825836
}
826837

827838
if (ExpressionChangedStart == ExpressionChangedEnd)
@@ -1175,7 +1186,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
11751186
return nullptr;
11761187

11771188
SmallVector<RepeatedValue, 8> Tree;
1178-
MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts);
1189+
bool HasNUW = true;
1190+
MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
11791191
SmallVector<ValueEntry, 8> Factors;
11801192
Factors.reserve(Tree.size());
11811193
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1217,7 +1229,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
12171229

12181230
if (!FoundFactor) {
12191231
// Make sure to restore the operands to the expression tree.
1220-
RewriteExprTree(BO, Factors);
1232+
RewriteExprTree(BO, Factors, HasNUW);
12211233
return nullptr;
12221234
}
12231235

@@ -1229,7 +1241,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
12291241
RedoInsts.insert(BO);
12301242
V = Factors[0].Op;
12311243
} else {
1232-
RewriteExprTree(BO, Factors);
1244+
RewriteExprTree(BO, Factors, HasNUW);
12331245
V = BO;
12341246
}
12351247

@@ -2354,7 +2366,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
23542366
// First, walk the expression tree, linearizing the tree, collecting the
23552367
// operand information.
23562368
SmallVector<RepeatedValue, 8> Tree;
2357-
MadeChange |= LinearizeExprTree(I, Tree, RedoInsts);
2369+
bool HasNUW = true;
2370+
MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
23582371
SmallVector<ValueEntry, 8> Ops;
23592372
Ops.reserve(Tree.size());
23602373
for (const RepeatedValue &E : Tree)
@@ -2547,7 +2560,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
25472560
dbgs() << '\n');
25482561
// Now that we ordered and optimized the expressions, splat them back into
25492562
// the expression tree, removing any unneeded nodes.
2550-
RewriteExprTree(I, Ops);
2563+
RewriteExprTree(I, Ops, HasNUW);
25512564
}
25522565

25532566
void

llvm/test/Transforms/Reassociate/local-cse.ll

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
2626
; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
2727
; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
2828
; LOCAL_CSE-NEXT: bb1:
29-
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[INV2]], [[INV1]]
29+
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
3030
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
3131
; LOCAL_CSE: bb2:
3232
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
33-
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4]]
34-
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
35-
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5]]
36-
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
37-
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add i64 [[INV3]], [[INV1]]
38-
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[VAL_BB2]]
33+
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
34+
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
35+
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
36+
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
37+
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
38+
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
3939
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
4040
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
4141
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
@@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
4747
; CSE-NEXT: br label [[BB2:%.*]]
4848
; CSE: bb2:
4949
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
50-
; CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1]]
51-
; CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2]]
50+
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
51+
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
5252
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
5353
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
54-
; CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3]]
54+
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
5555
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
5656
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
5757
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
@@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
9090
; LOCAL_CSE-NEXT: br label [[BB1:%.*]]
9191
; LOCAL_CSE: bb1:
9292
; LOCAL_CSE-NEXT: [[INV1_BB1:%.*]] = call i64 @get_val()
93-
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[INV1_BB1]], [[INV2_BB0]]
93+
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
9494
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
9595
; LOCAL_CSE: bb2:
9696
; LOCAL_CSE-NEXT: [[INV3_BB2:%.*]] = call i64 @get_val()
9797
; LOCAL_CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
9898
; LOCAL_CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
9999
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
100-
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4_BB2]]
101-
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
102-
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5_BB2]]
103-
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
104-
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
105-
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[INV3_BB2]]
100+
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
101+
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
102+
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
103+
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
104+
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
105+
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
106106
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
107107
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
108108
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
@@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
120120
; CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
121121
; CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
122122
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
123-
; CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
124-
; CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2_BB0]]
123+
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
124+
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
125125
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
126126
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
127-
; CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3_BB2]]
127+
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
128128
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
129129
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
130130
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt < %s -passes=reassociate -S | FileCheck %s
3+
4+
; We cannot preserve nuw flags for mul
5+
define i4 @nuw_preserve_negative(i4 %a, i4 %b, i4 %c) {
6+
; CHECK-LABEL: define i4 @nuw_preserve_negative(
7+
; CHECK-SAME: i4 [[A:%.*]], i4 [[B:%.*]], i4 [[C:%.*]]) {
8+
; CHECK-NEXT: [[V0:%.*]] = mul i4 [[B]], [[A]]
9+
; CHECK-NEXT: [[V1:%.*]] = mul i4 [[V0]], [[C]]
10+
; CHECK-NEXT: ret i4 [[V1]]
11+
;
12+
%v0 = mul nuw i4 %a, %c
13+
%v1 = mul nuw i4 %v0, %b
14+
ret i4 %v1
15+
}
16+
17+
; TODO: we can add nuw flags if we know all operands are non-zero.
18+
define i4 @nuw_preserve_non_zero(i4 %a, i4 %b, i4 %c) {
19+
; CHECK-LABEL: define i4 @nuw_preserve_non_zero(
20+
; CHECK-SAME: i4 [[A:%.*]], i4 [[B:%.*]], i4 [[C:%.*]]) {
21+
; CHECK-NEXT: [[A0:%.*]] = add nuw i4 [[A]], 1
22+
; CHECK-NEXT: [[B0:%.*]] = add nuw i4 [[B]], 1
23+
; CHECK-NEXT: [[C0:%.*]] = add nuw i4 [[C]], 1
24+
; CHECK-NEXT: [[V0:%.*]] = mul i4 [[B0]], [[A0]]
25+
; CHECK-NEXT: [[V1:%.*]] = mul i4 [[V0]], [[C0]]
26+
; CHECK-NEXT: ret i4 [[V1]]
27+
;
28+
%a0 = add nuw i4 %a, 1
29+
%b0 = add nuw i4 %b, 1
30+
%c0 = add nuw i4 %c, 1
31+
%v0 = mul nuw i4 %a0, %c0
32+
%v1 = mul nuw i4 %v0, %b0
33+
ret i4 %v1
34+
}

0 commit comments

Comments
 (0)