Skip to content

[Reassociate] Preserve NUW flags after expr tree rewriting #72360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into llvm:main from reassociate-nuw
Dec 9, 2023

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Nov 15, 2023

@llvmbot
Copy link
Member

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

Alive2: https://alive2.llvm.org/ce/z/38KiC_

This missed optimization was discovered with the help of AliveToolkit/alive2#962.


Full diff: https://github.com/llvm/llvm-project/pull/72360.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Scalar/Reassociate.h (+2-1)
  • (modified) llvm/lib/Transforms/Scalar/Reassociate.cpp (+19-9)
  • (modified) llvm/test/Transforms/Reassociate/local-cse.ll (+20-20)
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 28794d27325adec..7e47f8ae5d81e96 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -102,7 +102,8 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
   void canonicalizeOperands(Instruction *I);
   void ReassociateExpression(BinaryOperator *I);
   void RewriteExprTree(BinaryOperator *I,
-                       SmallVectorImpl<reassociate::ValueEntry> &Ops);
+                       SmallVectorImpl<reassociate::ValueEntry> &Ops,
+                       bool HasNUW);
   Value *OptimizeExpression(BinaryOperator *I,
                             SmallVectorImpl<reassociate::ValueEntry> &Ops);
   Value *OptimizeAdd(Instruction *I,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 9c4a344d4295f8a..07e8f1b24d8c759 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -466,7 +466,8 @@ using RepeatedValue = std::pair<Value*, APInt>;
 /// type and thus make the expression bigger.
 static bool LinearizeExprTree(Instruction *I,
                               SmallVectorImpl<RepeatedValue> &Ops,
-                              ReassociatePass::OrderedSet &ToRedo) {
+                              ReassociatePass::OrderedSet &ToRedo,
+                              bool &HasNUW) {
   assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
          "Expected a UnaryOperator or BinaryOperator!");
   LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I,
     std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
     I = P.first; // We examine the operands of this binary operator.
 
+    if (isa<OverflowingBinaryOperator>(I))
+      HasNUW &= I->hasNoUnsignedWrap();
+
     for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
       Value *Op = I->getOperand(OpIdx);
       APInt Weight = P.second; // Number of paths to this operand.
@@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I,
 /// Now that the operands for this expression tree are
 /// linearized and optimized, emit them in-order.
 void ReassociatePass::RewriteExprTree(BinaryOperator *I,
-                                      SmallVectorImpl<ValueEntry> &Ops) {
+                                      SmallVectorImpl<ValueEntry> &Ops,
+                                      bool HasNUW) {
   assert(Ops.size() > 1 && "Single values should be used directly!");
 
   // Since our optimizations should never increase the number of operations, the
@@ -814,14 +819,17 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
   if (ExpressionChangedStart) {
     bool ClearFlags = true;
     do {
-      // Preserve FastMathFlags.
+      // Preserve flags.
       if (ClearFlags) {
         if (isa<FPMathOperator>(I)) {
           FastMathFlags Flags = I->getFastMathFlags();
           ExpressionChangedStart->clearSubclassOptionalData();
           ExpressionChangedStart->setFastMathFlags(Flags);
-        } else
+        } else {
           ExpressionChangedStart->clearSubclassOptionalData();
+          if (HasNUW && isa<OverflowingBinaryOperator>(ExpressionChangedStart))
+            ExpressionChangedStart->setHasNoUnsignedWrap();
+        }
       }
 
       if (ExpressionChangedStart == ExpressionChangedEnd)
@@ -1171,7 +1179,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     return nullptr;
 
   SmallVector<RepeatedValue, 8> Tree;
-  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts);
+  bool HasNUW = true;
+  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
   SmallVector<ValueEntry, 8> Factors;
   Factors.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1213,7 +1222,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
 
   if (!FoundFactor) {
     // Make sure to restore the operands to the expression tree.
-    RewriteExprTree(BO, Factors);
+    RewriteExprTree(BO, Factors, HasNUW);
     return nullptr;
   }
 
@@ -1225,7 +1234,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     RedoInsts.insert(BO);
     V = Factors[0].Op;
   } else {
-    RewriteExprTree(BO, Factors);
+    RewriteExprTree(BO, Factors, HasNUW);
     V = BO;
   }
 
@@ -2349,7 +2358,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
   // First, walk the expression tree, linearizing the tree, collecting the
   // operand information.
   SmallVector<RepeatedValue, 8> Tree;
-  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts);
+  bool HasNUW = true;
+  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
   SmallVector<ValueEntry, 8> Ops;
   Ops.reserve(Tree.size());
   for (const RepeatedValue &E : Tree)
@@ -2542,7 +2552,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
              dbgs() << '\n');
   // Now that we ordered and optimized the expressions, splat them back into
   // the expression tree, removing any unneeded nodes.
-  RewriteExprTree(I, Ops);
+  RewriteExprTree(I, Ops, HasNUW);
 }
 
 void
diff --git a/llvm/test/Transforms/Reassociate/local-cse.ll b/llvm/test/Transforms/Reassociate/local-cse.ll
index 1609cb1b36fd93e..4d0467e263f5538 100644
--- a/llvm/test/Transforms/Reassociate/local-cse.ll
+++ b/llvm/test/Transforms/Reassociate/local-cse.ll
@@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
 ; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
 ; LOCAL_CSE-NEXT:  bb1:
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[INV2]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add i64 [[INV3]], [[INV1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; CSE-NEXT:    br label [[BB2:%.*]]
 ; CSE:       bb2:
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; LOCAL_CSE-NEXT:    br label [[BB1:%.*]]
 ; LOCAL_CSE:       bb1:
 ; LOCAL_CSE-NEXT:    [[INV1_BB1:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[INV1_BB1]], [[INV2_BB0]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[INV3_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[INV3_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2_BB0]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3_BB2]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks plausible to me, but I'm not familiar with this code.

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Nov 15, 2023

Sorry, preserving the NUW flag doesn't hold for mul if one of the operands is zero.
Alive2: https://alive2.llvm.org/ce/z/686-w6

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Dec 8, 2023

Ping.

@dtcxzyw dtcxzyw requested review from preames and topperc December 8, 2023 06:10
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw dtcxzyw merged commit 312cb34 into llvm:main Dec 9, 2023
@dtcxzyw dtcxzyw deleted the reassociate-nuw branch December 9, 2023 08:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants