Skip to content

[InstCombine] Simplify nonnull pointers #128111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 22, 2025
Merged

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Feb 21, 2025

This patch is the follow-up of #127979. It introduces a helper simplifyNonNullOperand to avoid duplicate logic. It also addresses the one-use issue in visitLoadInst, as discussed in #127979 (comment).
The nonnull attribute is also supported. Proof: https://alive2.llvm.org/ce/z/MCKgT9

Based on #128107.

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch is the follow-up of #127979. It introduces a helper simplifyNonNullOperand to avoid duplicate logic. It also addresses the one-use issue in visitLoadInst, as discussed in #127979 (comment).
The nonnull attribute is also supported. Proof: https://alive2.llvm.org/ce/z/MCKgT9

Based on #128107.


Full diff: https://github.com/llvm/llvm-project/pull/128111.diff

9 Files Affected:

  • (modified) llvm/include/llvm/IR/Function.h (+5)
  • (modified) llvm/lib/IR/Function.cpp (+11)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+14-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp (+21-25)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+9-1)
  • (modified) llvm/test/Transforms/InstCombine/nonnull-select.ll (+9-20)
  • (modified) llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll (+2-5)
  • (added) llvm/test/Transforms/PhaseOrdering/memset-combine.ll (+20)
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index 29041688124bc..7ea8673bedad1 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -731,6 +731,11 @@ class LLVM_ABI Function : public GlobalObject, public ilist_node<Function> {
   /// create a Function) from the Function Src to this one.
   void copyAttributesFrom(const Function *Src);
 
+  /// Return true if the return value is known to be not null.
+  /// This may be because it has the nonnull attribute, or because at least
+  /// one byte is dereferenceable and the pointer is in addrspace(0).
+  bool isReturnNonNull() const;
+
   /// deleteBody - This method deletes the body of the function, and converts
   /// the linkage to external.
   ///
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 5666f0a53866f..d22cf65769e26 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -873,6 +873,17 @@ void Function::copyAttributesFrom(const Function *Src) {
     setPrologueData(Src->getPrologueData());
 }
 
+bool Function::isReturnNonNull() const {
+  if (hasRetAttribute(Attribute::NonNull))
+    return true;
+
+  if (AttributeSets.getRetDereferenceableBytes() > 0 &&
+      !NullPointerIsDefined(this, getReturnType()->getPointerAddressSpace()))
+    return true;
+
+  return false;
+}
+
 MemoryEffects Function::getMemoryEffects() const {
   return getAttributes().getMemoryEffects();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 400ebcf493713..c8b3d29c3aa98 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3993,10 +3993,20 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
   unsigned ArgNo = 0;
 
   for (Value *V : Call.args()) {
-    if (V->getType()->isPointerTy() &&
-        !Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
-        isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
-      ArgNos.push_back(ArgNo);
+    if (V->getType()->isPointerTy()) {
+      // Simplify the nonnull operand before nonnull inference to avoid
+      // unnecessary queries.
+      if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) {
+        if (Value *Res = simplifyNonNullOperand(V)) {
+          replaceOperand(Call, ArgNo, Res);
+          Changed = true;
+        }
+      }
+
+      if (!Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
+          isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
+        ArgNos.push_back(ArgNo);
+    }
     ArgNo++;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda..71c80d4c401f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -455,6 +455,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource);
 
+  /// Simplify \p V given that it is known to be non-null.
+  /// Returns the simplified value if possible, otherwise returns nullptr.
+  Value *simplifyNonNullOperand(Value *V);
+
 public:
   /// Create and insert the idiom we use to indicate a block is unreachable
   /// without having to rewrite the CFG from within InstCombine.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index d5534c15cca76..89fc1051b18dc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -982,6 +982,19 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
   return false;
 }
 
+/// TODO: Recursively simplify nonnull value to handle one-use inbounds GEPs.
+Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
+  if (auto *Sel = dyn_cast<SelectInst>(V)) {
+    if (isa<ConstantPointerNull>(Sel->getOperand(1)))
+      return Sel->getOperand(2);
+
+    if (isa<ConstantPointerNull>(Sel->getOperand(2)))
+      return Sel->getOperand(1);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
   Value *Op = LI.getOperand(0);
   if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI)))
@@ -1059,20 +1072,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
         V2->copyMetadata(LI, Metadata::PoisonGeneratingIDs);
         return SelectInst::Create(SI->getCondition(), V1, V2);
       }
-
-      // load (select (cond, null, P)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(1)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(2));
-
-      // load (select (cond, P, null)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(2)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(1));
     }
   }
+
+  if (!NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Op))
+      return replaceOperand(LI, 0, V);
+
   return nullptr;
 }
 
@@ -1437,19 +1443,9 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
   if (isa<UndefValue>(Val))
     return eraseInstFromFunction(SI);
 
-  // TODO: Add a helper to simplify the pointer operand for all memory
-  // instructions.
-  // store val, (select (cond, null, P)) -> store val, P
-  // store val, (select (cond, P, null)) -> store val, P
-  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) {
-    if (SelectInst *Sel = dyn_cast<SelectInst>(Ptr)) {
-      if (isa<ConstantPointerNull>(Sel->getOperand(1)))
-        return replaceOperand(SI, 1, Sel->getOperand(2));
-
-      if (isa<ConstantPointerNull>(Sel->getOperand(2)))
-        return replaceOperand(SI, 1, Sel->getOperand(1));
-    }
-  }
+  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Ptr))
+      return replaceOperand(SI, 1, V);
 
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5621511570b58..d3af06f63fcd2 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3587,7 +3587,15 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
 
 Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
   Value *RetVal = RI.getReturnValue();
-  if (!RetVal || !AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
+  if (!RetVal)
+    return nullptr;
+
+  if (RetVal->getType()->isPointerTy() && RI.getFunction()->isReturnNonNull()) {
+    if (Value *V = simplifyNonNullOperand(RetVal))
+      return replaceOperand(RI, 0, V);
+  }
+
+  if (!AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
     return nullptr;
 
   Function *F = RI.getFunction();
diff --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll
index 3fab2dfb41a42..cc000b4c88164 100644
--- a/llvm/test/Transforms/InstCombine/nonnull-select.ll
+++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll
@@ -5,10 +5,7 @@
 
 define nonnull ptr @pr48975(ptr %.0) {
 ; CHECK-LABEL: @pr48975(
-; CHECK-NEXT:    [[DOT1:%.*]] = load ptr, ptr [[DOT0:%.*]], align 8
-; CHECK-NEXT:    [[DOT2:%.*]] = icmp eq ptr [[DOT1]], null
-; CHECK-NEXT:    [[DOT4:%.*]] = select i1 [[DOT2]], ptr null, ptr [[DOT0]]
-; CHECK-NEXT:    ret ptr [[DOT4]]
+; CHECK-NEXT:    ret ptr [[DOT4:%.*]]
 ;
   %.1 = load ptr, ptr %.0, align 8
   %.2 = icmp eq ptr %.1, null
@@ -18,8 +15,7 @@ define nonnull ptr @pr48975(ptr %.0) {
 
 define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -27,8 +23,7 @@ define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 
 define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -36,8 +31,7 @@ define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -45,8 +39,7 @@ define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -55,8 +48,7 @@ define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 
 define void @nonnull_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -66,8 +58,7 @@ define void @nonnull_call(i1 %cond, ptr %p) {
 
 define void @nonnull_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
@@ -77,8 +68,7 @@ define void @nonnull_call2(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -88,8 +78,7 @@ define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
diff --git a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
index d8ef0723cf09e..f6bf57a678786 100644
--- a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
+++ b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
@@ -1,24 +1,21 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -passes='instcombine,early-cse<memssa>' -S %s | FileCheck %s
 
-; FIXME: We can remove the store instruction in the exit block
 define i32 @load_store_sameval(ptr %p, i1 %cond1, i1 %cond2) {
 ; CHECK-LABEL: define i32 @load_store_sameval(
 ; CHECK-SAME: ptr [[P:%.*]], i1 [[COND1:%.*]], i1 [[COND2:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = select i1 [[COND1]], ptr null, ptr [[P]]
-; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[SPEC_SELECT]], align 4
+; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[P]], align 4
 ; CHECK-NEXT:    br label %[[BLOCK:.*]]
 ; CHECK:       [[BLOCK]]:
 ; CHECK-NEXT:    br label %[[BLOCK2:.*]]
 ; CHECK:       [[BLOCK2]]:
 ; CHECK-NEXT:    br i1 [[COND2]], label %[[BLOCK3:.*]], label %[[EXIT:.*]]
 ; CHECK:       [[BLOCK3]]:
-; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[SPEC_SELECT]], align 8
+; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[P]], align 8
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp une double [[LOAD]], 0.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP]], label %[[BLOCK]], label %[[BLOCK2]]
 ; CHECK:       [[EXIT]]:
-; CHECK-NEXT:    store i32 [[PRE]], ptr [[P]], align 4
 ; CHECK-NEXT:    ret i32 0
 ;
 entry:
diff --git a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
new file mode 100644
index 0000000000000..d1de11258ed91
--- /dev/null
+++ b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt < %s -passes=instcombine,memcpyopt -S | FileCheck %s
+
+; FIXME: These two memset calls should be merged into a single one.
+define void @merge_memset(ptr %p, i1 %cond) {
+; CHECK-LABEL: define void @merge_memset(
+; CHECK-SAME: ptr [[P:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], ptr null, ptr [[P]]
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(4096) [[P]], i8 0, i64 4096, i1 false)
+; CHECK-NEXT:    [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[SEL]], i64 4096
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(768) [[OFF]], i8 0, i64 768, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %sel = select i1 %cond, ptr null, ptr %p
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %sel, i8 0, i64 4096, i1 false)
+  %off = getelementptr inbounds nuw i8, ptr %sel, i64 4096
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %off, i8 0, i64 768, i1 false)
+  ret void
+}

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch is the follow-up of #127979. It introduces a helper simplifyNonNullOperand to avoid duplicate logic. It also addresses the one-use issue in visitLoadInst, as discussed in #127979 (comment).
The nonnull attribute is also supported. Proof: https://alive2.llvm.org/ce/z/MCKgT9

Based on #128107.


Full diff: https://github.com/llvm/llvm-project/pull/128111.diff

9 Files Affected:

  • (modified) llvm/include/llvm/IR/Function.h (+5)
  • (modified) llvm/lib/IR/Function.cpp (+11)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+14-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp (+21-25)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+9-1)
  • (modified) llvm/test/Transforms/InstCombine/nonnull-select.ll (+9-20)
  • (modified) llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll (+2-5)
  • (added) llvm/test/Transforms/PhaseOrdering/memset-combine.ll (+20)
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index 29041688124bc..7ea8673bedad1 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -731,6 +731,11 @@ class LLVM_ABI Function : public GlobalObject, public ilist_node<Function> {
   /// create a Function) from the Function Src to this one.
   void copyAttributesFrom(const Function *Src);
 
+  /// Return true if the return value is known to be not null.
+  /// This may be because it has the nonnull attribute, or because at least
+  /// one byte is dereferenceable and the pointer is in addrspace(0).
+  bool isReturnNonNull() const;
+
   /// deleteBody - This method deletes the body of the function, and converts
   /// the linkage to external.
   ///
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 5666f0a53866f..d22cf65769e26 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -873,6 +873,17 @@ void Function::copyAttributesFrom(const Function *Src) {
     setPrologueData(Src->getPrologueData());
 }
 
+bool Function::isReturnNonNull() const {
+  if (hasRetAttribute(Attribute::NonNull))
+    return true;
+
+  if (AttributeSets.getRetDereferenceableBytes() > 0 &&
+      !NullPointerIsDefined(this, getReturnType()->getPointerAddressSpace()))
+    return true;
+
+  return false;
+}
+
 MemoryEffects Function::getMemoryEffects() const {
   return getAttributes().getMemoryEffects();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 400ebcf493713..c8b3d29c3aa98 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3993,10 +3993,20 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
   unsigned ArgNo = 0;
 
   for (Value *V : Call.args()) {
-    if (V->getType()->isPointerTy() &&
-        !Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
-        isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
-      ArgNos.push_back(ArgNo);
+    if (V->getType()->isPointerTy()) {
+      // Simplify the nonnull operand before nonnull inference to avoid
+      // unnecessary queries.
+      if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) {
+        if (Value *Res = simplifyNonNullOperand(V)) {
+          replaceOperand(Call, ArgNo, Res);
+          Changed = true;
+        }
+      }
+
+      if (!Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
+          isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
+        ArgNos.push_back(ArgNo);
+    }
     ArgNo++;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda..71c80d4c401f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -455,6 +455,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource);
 
+  /// Simplify \p V given that it is known to be non-null.
+  /// Returns the simplified value if possible, otherwise returns nullptr.
+  Value *simplifyNonNullOperand(Value *V);
+
 public:
   /// Create and insert the idiom we use to indicate a block is unreachable
   /// without having to rewrite the CFG from within InstCombine.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index d5534c15cca76..89fc1051b18dc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -982,6 +982,19 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
   return false;
 }
 
+/// TODO: Recursively simplify nonnull value to handle one-use inbounds GEPs.
+Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
+  if (auto *Sel = dyn_cast<SelectInst>(V)) {
+    if (isa<ConstantPointerNull>(Sel->getOperand(1)))
+      return Sel->getOperand(2);
+
+    if (isa<ConstantPointerNull>(Sel->getOperand(2)))
+      return Sel->getOperand(1);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
   Value *Op = LI.getOperand(0);
   if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI)))
@@ -1059,20 +1072,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
         V2->copyMetadata(LI, Metadata::PoisonGeneratingIDs);
         return SelectInst::Create(SI->getCondition(), V1, V2);
       }
-
-      // load (select (cond, null, P)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(1)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(2));
-
-      // load (select (cond, P, null)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(2)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(1));
     }
   }
+
+  if (!NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Op))
+      return replaceOperand(LI, 0, V);
+
   return nullptr;
 }
 
@@ -1437,19 +1443,9 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
   if (isa<UndefValue>(Val))
     return eraseInstFromFunction(SI);
 
-  // TODO: Add a helper to simplify the pointer operand for all memory
-  // instructions.
-  // store val, (select (cond, null, P)) -> store val, P
-  // store val, (select (cond, P, null)) -> store val, P
-  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) {
-    if (SelectInst *Sel = dyn_cast<SelectInst>(Ptr)) {
-      if (isa<ConstantPointerNull>(Sel->getOperand(1)))
-        return replaceOperand(SI, 1, Sel->getOperand(2));
-
-      if (isa<ConstantPointerNull>(Sel->getOperand(2)))
-        return replaceOperand(SI, 1, Sel->getOperand(1));
-    }
-  }
+  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Ptr))
+      return replaceOperand(SI, 1, V);
 
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5621511570b58..d3af06f63fcd2 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3587,7 +3587,15 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
 
 Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
   Value *RetVal = RI.getReturnValue();
-  if (!RetVal || !AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
+  if (!RetVal)
+    return nullptr;
+
+  if (RetVal->getType()->isPointerTy() && RI.getFunction()->isReturnNonNull()) {
+    if (Value *V = simplifyNonNullOperand(RetVal))
+      return replaceOperand(RI, 0, V);
+  }
+
+  if (!AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
     return nullptr;
 
   Function *F = RI.getFunction();
diff --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll
index 3fab2dfb41a42..cc000b4c88164 100644
--- a/llvm/test/Transforms/InstCombine/nonnull-select.ll
+++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll
@@ -5,10 +5,7 @@
 
 define nonnull ptr @pr48975(ptr %.0) {
 ; CHECK-LABEL: @pr48975(
-; CHECK-NEXT:    [[DOT1:%.*]] = load ptr, ptr [[DOT0:%.*]], align 8
-; CHECK-NEXT:    [[DOT2:%.*]] = icmp eq ptr [[DOT1]], null
-; CHECK-NEXT:    [[DOT4:%.*]] = select i1 [[DOT2]], ptr null, ptr [[DOT0]]
-; CHECK-NEXT:    ret ptr [[DOT4]]
+; CHECK-NEXT:    ret ptr [[DOT4:%.*]]
 ;
   %.1 = load ptr, ptr %.0, align 8
   %.2 = icmp eq ptr %.1, null
@@ -18,8 +15,7 @@ define nonnull ptr @pr48975(ptr %.0) {
 
 define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -27,8 +23,7 @@ define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 
 define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -36,8 +31,7 @@ define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -45,8 +39,7 @@ define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -55,8 +48,7 @@ define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 
 define void @nonnull_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -66,8 +58,7 @@ define void @nonnull_call(i1 %cond, ptr %p) {
 
 define void @nonnull_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
@@ -77,8 +68,7 @@ define void @nonnull_call2(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -88,8 +78,7 @@ define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
diff --git a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
index d8ef0723cf09e..f6bf57a678786 100644
--- a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
+++ b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
@@ -1,24 +1,21 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -passes='instcombine,early-cse<memssa>' -S %s | FileCheck %s
 
-; FIXME: We can remove the store instruction in the exit block
 define i32 @load_store_sameval(ptr %p, i1 %cond1, i1 %cond2) {
 ; CHECK-LABEL: define i32 @load_store_sameval(
 ; CHECK-SAME: ptr [[P:%.*]], i1 [[COND1:%.*]], i1 [[COND2:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = select i1 [[COND1]], ptr null, ptr [[P]]
-; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[SPEC_SELECT]], align 4
+; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[P]], align 4
 ; CHECK-NEXT:    br label %[[BLOCK:.*]]
 ; CHECK:       [[BLOCK]]:
 ; CHECK-NEXT:    br label %[[BLOCK2:.*]]
 ; CHECK:       [[BLOCK2]]:
 ; CHECK-NEXT:    br i1 [[COND2]], label %[[BLOCK3:.*]], label %[[EXIT:.*]]
 ; CHECK:       [[BLOCK3]]:
-; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[SPEC_SELECT]], align 8
+; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[P]], align 8
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp une double [[LOAD]], 0.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP]], label %[[BLOCK]], label %[[BLOCK2]]
 ; CHECK:       [[EXIT]]:
-; CHECK-NEXT:    store i32 [[PRE]], ptr [[P]], align 4
 ; CHECK-NEXT:    ret i32 0
 ;
 entry:
diff --git a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
new file mode 100644
index 0000000000000..d1de11258ed91
--- /dev/null
+++ b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt < %s -passes=instcombine,memcpyopt -S | FileCheck %s
+
+; FIXME: These two memset calls should be merged into a single one.
+define void @merge_memset(ptr %p, i1 %cond) {
+; CHECK-LABEL: define void @merge_memset(
+; CHECK-SAME: ptr [[P:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], ptr null, ptr [[P]]
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(4096) [[P]], i8 0, i64 4096, i1 false)
+; CHECK-NEXT:    [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[SEL]], i64 4096
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(768) [[OFF]], i8 0, i64 768, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %sel = select i1 %cond, ptr null, ptr %p
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %sel, i8 0, i64 4096, i1 false)
+  %off = getelementptr inbounds nuw i8, ptr %sel, i64 4096
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %off, i8 0, i64 768, i1 false)
+  ret void
+}

@nikic
Copy link
Contributor

nikic commented Feb 21, 2025

replaceOperand(Call, ArgNo, Res);
Changed = true;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can else this instead of querying nonnull again? (Will no longer infer nonnull for dereferenceable, but we shouldn't need to ?)

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@asmok-g
Copy link

asmok-g commented Mar 7, 2025

This is causing a probable miscompile.

I'm working on a repro. Bisecting for the exact file and exact function that causes the miscompile when optimized; the IR diff looks like, before:

; Function Attrs: mustprogress nounwind uwtable
define linkonce_odr dso_local void @_ZN3gvr21WST12GetTransformEv(ptr dead_on_unwind noalias writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(32) %1) unnamed_addr #0 comdat align 2 {
  %3 = getelementptr inbounds nuw i8, ptr %1, i64 16
  %4 = getelementptr inbounds nuw i8, ptr %1, i64 24
  %5 = load ptr, ptr %4, align 8, !tbaa !51, !nonnull !50, !noundef !50
  %6 = tail call noundef ptr @_ZNSt3__u19__shared_weak_count4lockEv(ptr noundef nonnull align 8 dereferenceable(24) %5) #12
  %7 = icmp eq ptr %6, null
  %8 = load ptr, ptr %3, align 8
  %9 = select i1 %7, ptr null, ptr %8
  %10 = load ptr, ptr %9, align 8, !tbaa !3
  %11 = getelementptr inbounds nuw i8, ptr %10, i64 24
  %12 = load ptr, ptr %11, align 8
  tail call void %12(ptr dead_on_unwind writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(112) %9) #12
  br i1 %7, label %21, label %13

13:                                               ; preds = %2
  %14 = getelementptr inbounds nuw i8, ptr %6, i64 8
  %15 = atomicrmw add ptr %14, i64 -1 acq_rel, align 8
  %16 = icmp eq i64 %15, 0
  br i1 %16, label %17, label %21

17:                                               ; preds = %13
  %18 = load ptr, ptr %6, align 8, !tbaa !3
  %19 = getelementptr inbounds nuw i8, ptr %18, i64 16
  %20 = load ptr, ptr %19, align 8
  tail call void %20(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  tail call void @_ZNSt3__u19__shared_weak_count14__release_weakEv(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  br label %21

21:                                               ; preds = %2, %13, %17
  ret void
}

after:

; Function Attrs: mustprogress nounwind uwtable
define linkonce_odr dso_local void @_ZN3gvr21WST12GetTransformEv(ptr dead_on_unwind noalias writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(32) %1) unnamed_addr #0 comdat align 2 {
  %3 = getelementptr inbounds nuw i8, ptr %1, i64 16
  %4 = getelementptr inbounds nuw i8, ptr %1, i64 24
  %5 = load ptr, ptr %4, align 8, !tbaa !51, !nonnull !50, !noundef !50
  %6 = tail call noundef ptr @_ZNSt3__u19__shared_weak_count4lockEv(ptr noundef nonnull align 8 dereferenceable(24) %5) #12
  %7 = load ptr, ptr %3, align 8
  %8 = load ptr, ptr %7, align 8, !tbaa !3
  %9 = getelementptr inbounds nuw i8, ptr %8, i64 24
  %10 = load ptr, ptr %9, align 8
  tail call void %10(ptr dead_on_unwind writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(112) %7) #12
  %11 = icmp eq ptr %6, null
  br i1 %11, label %20, label %12

12:                                               ; preds = %2
  %13 = getelementptr inbounds nuw i8, ptr %6, i64 8
  %14 = atomicrmw add ptr %13, i64 -1 acq_rel, align 8
  %15 = icmp eq i64 %14, 0
  br i1 %15, label %16, label %20

16:                                               ; preds = %12
  %17 = load ptr, ptr %6, align 8, !tbaa !3
  %18 = getelementptr inbounds nuw i8, ptr %17, i64 16
  %19 = load ptr, ptr %18, align 8
  tail call void %19(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  tail call void @_ZNSt3__u19__shared_weak_count14__release_weakEv(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  br label %20

20:                                               ; preds = %2, %12, %16
  ret void
}

The clang invocation to reproduce is clang -fno-exceptions -O3 '-std=gnu++20' pre.ii -emit-llvm -S -o case.ll

The entire preprocessed file that i'm reducing has other changes, but i thought i'd show the diff for the specific function will we have a completer reduced case.

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Mar 7, 2025

This is causing a probable miscompile.

I'm working on a repro. Bisecting for the exact file and exact function that causes the miscompile when optimized; the IR diff looks like, before:

; Function Attrs: mustprogress nounwind uwtable
define linkonce_odr dso_local void @_ZN3gvr21WST12GetTransformEv(ptr dead_on_unwind noalias writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(32) %1) unnamed_addr #0 comdat align 2 {
  %3 = getelementptr inbounds nuw i8, ptr %1, i64 16
  %4 = getelementptr inbounds nuw i8, ptr %1, i64 24
  %5 = load ptr, ptr %4, align 8, !tbaa !51, !nonnull !50, !noundef !50
  %6 = tail call noundef ptr @_ZNSt3__u19__shared_weak_count4lockEv(ptr noundef nonnull align 8 dereferenceable(24) %5) #12
  %7 = icmp eq ptr %6, null
  %8 = load ptr, ptr %3, align 8
  %9 = select i1 %7, ptr null, ptr %8
  %10 = load ptr, ptr %9, align 8, !tbaa !3
  %11 = getelementptr inbounds nuw i8, ptr %10, i64 24
  %12 = load ptr, ptr %11, align 8
  tail call void %12(ptr dead_on_unwind writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(112) %9) #12
  br i1 %7, label %21, label %13

13:                                               ; preds = %2
  %14 = getelementptr inbounds nuw i8, ptr %6, i64 8
  %15 = atomicrmw add ptr %14, i64 -1 acq_rel, align 8
  %16 = icmp eq i64 %15, 0
  br i1 %16, label %17, label %21

17:                                               ; preds = %13
  %18 = load ptr, ptr %6, align 8, !tbaa !3
  %19 = getelementptr inbounds nuw i8, ptr %18, i64 16
  %20 = load ptr, ptr %19, align 8
  tail call void %20(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  tail call void @_ZNSt3__u19__shared_weak_count14__release_weakEv(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  br label %21

21:                                               ; preds = %2, %13, %17
  ret void
}

after:

; Function Attrs: mustprogress nounwind uwtable
define linkonce_odr dso_local void @_ZN3gvr21WST12GetTransformEv(ptr dead_on_unwind noalias writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(32) %1) unnamed_addr #0 comdat align 2 {
  %3 = getelementptr inbounds nuw i8, ptr %1, i64 16
  %4 = getelementptr inbounds nuw i8, ptr %1, i64 24
  %5 = load ptr, ptr %4, align 8, !tbaa !51, !nonnull !50, !noundef !50
  %6 = tail call noundef ptr @_ZNSt3__u19__shared_weak_count4lockEv(ptr noundef nonnull align 8 dereferenceable(24) %5) #12
  %7 = load ptr, ptr %3, align 8
  %8 = load ptr, ptr %7, align 8, !tbaa !3
  %9 = getelementptr inbounds nuw i8, ptr %8, i64 24
  %10 = load ptr, ptr %9, align 8
  tail call void %10(ptr dead_on_unwind writable sret(%"M") align 4 %0, ptr noundef nonnull align 8 dereferenceable(112) %7) #12
  %11 = icmp eq ptr %6, null
  br i1 %11, label %20, label %12

12:                                               ; preds = %2
  %13 = getelementptr inbounds nuw i8, ptr %6, i64 8
  %14 = atomicrmw add ptr %13, i64 -1 acq_rel, align 8
  %15 = icmp eq i64 %14, 0
  br i1 %15, label %16, label %20

16:                                               ; preds = %12
  %17 = load ptr, ptr %6, align 8, !tbaa !3
  %18 = getelementptr inbounds nuw i8, ptr %17, i64 16
  %19 = load ptr, ptr %18, align 8
  tail call void %19(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  tail call void @_ZNSt3__u19__shared_weak_count14__release_weakEv(ptr noundef nonnull align 8 dereferenceable(24) %6) #12
  br label %20

20:                                               ; preds = %2, %12, %16
  ret void
}

The clang invocation to reproduce is clang -fno-exceptions -O3 '-std=gnu++20' pre.ii -emit-llvm -S -o case.ll

The entire preprocessed file that i'm reducing has other changes, but i thought i'd show the diff for the specific function will we have a completer reduced case.

IIRC this transformation is correct. If %9 evaluates to null, we will hit a UB at the following load/call instructions.

@asmok-g
Copy link

asmok-g commented Mar 10, 2025

Here's a reduction attempt of the whole preprocessed (might be over-reduced):

namespace {
template <bool, class a> using b = a;
template <int c> struct ab {
  static const int aa = c;
};
template <class d, class e> struct f : ab<__is_convertible(d, e)> {};
template <int> struct g {
  template <class, class h, class... i>
  using j = g<!h::aa>::template j<h, i...>;
};
template <> struct g<false> {
  template <class k, class> using j = k;
};
template <class... l> using m = g<sizeof...(l)>::template j<ab<false>, l...>;
template <class n, class a> struct ac : m<f<n, a>, int> {};
template <class a> struct o {
  typedef a p;
  p *q;
  o() : q() {}
  p *operator->() { return q; }
};
template <class a> struct r {
  a *q;
  int *s;
  template <class n, b<ac<n, a>::aa, int> = 0> r(o<n>);
  o<a> ad() {
    o<a> t;
    if (s)
      t.q = q;
    return t;
  }
};
} // namespace
struct ae {
  int u[];
};
struct v {
  virtual ae y();
};
o<v> af;
struct z : v {
  z() : w(af) {}
  ae y() {
    o x = w.ad();
    return x->y();
  }
  r<v> w;
};
void ag() { z(); }

cmd: clang -fno-exceptions -O3 '-std=gnu++20' reduced.ii -emit-llvm -S -o

YutongZhuu can you please take another look?

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Mar 10, 2025

Here's a reduction attempt of the whole preprocessed (might be over-reduced):

namespace {
template <bool, class a> using b = a;
template <int c> struct ab {
  static const int aa = c;
};
template <class d, class e> struct f : ab<__is_convertible(d, e)> {};
template <int> struct g {
  template <class, class h, class... i>
  using j = g<!h::aa>::template j<h, i...>;
};
template <> struct g<false> {
  template <class k, class> using j = k;
};
template <class... l> using m = g<sizeof...(l)>::template j<ab<false>, l...>;
template <class n, class a> struct ac : m<f<n, a>, int> {};
template <class a> struct o {
  typedef a p;
  p *q;
  o() : q() {}
  p *operator->() { return q; }
};
template <class a> struct r {
  a *q;
  int *s;
  template <class n, b<ac<n, a>::aa, int> = 0> r(o<n>);
  o<a> ad() {
    o<a> t;
    if (s)
      t.q = q;
    return t;
  }
};
} // namespace
struct ae {
  int u[];
};
struct v {
  virtual ae y();
};
o<v> af;
struct z : v {
  z() : w(af) {}
  ae y() {
    o x = w.ad();
    return x->y();
  }
  r<v> w;
};
void ag() { z(); }

cmd: clang -fno-exceptions -O3 '-std=gnu++20' reduced.ii -emit-llvm -S -o

YutongZhuu can you please take another look?

Can you please provide a single-file, executable reproducer without UB?

@asmok-g
Copy link

asmok-g commented Mar 10, 2025

I'm working on it; but to give a bit of explanation: the test that fails now is a death-test. It's heap-use-after-free access to a weak_ptr. The problem is, after this patch the test fails to die. It returns the object that should have already been destroyed. I'll try to get a better repro.

@asmok-g
Copy link

asmok-g commented Mar 10, 2025

Sorry for the false alarm here. The tests were just testing UB. Will update the tests. Sorry again for the inconvenience.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants