-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SimplifyCFG] Switch to use paramHasNonNullAttr
#125383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…n `Use` (#125519) Address comment #125383 (comment)
…` to work on `Use` (#125519) Address comment llvm/llvm-project#125383 (comment)
…n `Use` (llvm#125519) Address comment llvm#125383 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs a rebase?
7e90465
to
8cdfdae
Compare
8cdfdae
to
244bf1f
Compare
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-ir Author: Yingwei Zheng (dtcxzyw) ChangesFull diff: https://github.com/llvm/llvm-project/pull/125383.diff 3 Files Affected:
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 8e47e3c7b3a7c..61070aa79b15d 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1839,7 +1839,11 @@ class CallBase : public Instruction {
/// Extract the number of dereferenceable bytes for a call or
/// parameter (0=unknown).
uint64_t getParamDereferenceableBytes(unsigned i) const {
- return Attrs.getParamDereferenceableBytes(i);
+ uint64_t Bytes = Attrs.getParamDereferenceableBytes(i);
+ if (const Function *F = getCalledFunction())
+ Bytes =
+ std::max(Bytes, F->getAttributes().getParamDereferenceableBytes(i));
+ return Bytes;
}
/// Extract the number of dereferenceable_or_null bytes for a call
@@ -1857,7 +1861,11 @@ class CallBase : public Instruction {
/// Extract the number of dereferenceable_or_null bytes for a
/// parameter (0=unknown).
uint64_t getParamDereferenceableOrNullBytes(unsigned i) const {
- return Attrs.getParamDereferenceableOrNullBytes(i);
+ uint64_t Bytes = Attrs.getParamDereferenceableOrNullBytes(i);
+ if (const Function *F = getCalledFunction())
+ Bytes = std::max(
+ Bytes, F->getAttributes().getParamDereferenceableOrNullBytes(i));
+ return Bytes;
}
/// Extract a test mask for disallowed floating-point value classes for the
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 7f53aa7d4f73d..ea171010c7507 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -8214,8 +8214,8 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu
if (CB->isArgOperand(&Use)) {
unsigned ArgIdx = CB->getArgOperandNo(&Use);
// Passing null to a nonnnull+noundef argument is undefined.
- if (C->isNullValue() && CB->isPassingUndefUB(ArgIdx) &&
- CB->paramHasAttr(ArgIdx, Attribute::NonNull))
+ if (isa<ConstantPointerNull>(C) &&
+ CB->paramHasNonNullAttr(ArgIdx, /*AllowUndefOrPoison=*/false))
return !PtrValueMayBeModified;
// Passing undef to a noundef argument is undefined.
if (isa<UndefValue>(C) && CB->isPassingUndefUB(ArgIdx))
diff --git a/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll b/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
index aae1ab032f36e..2da5d18b63f49 100644
--- a/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
+++ b/llvm/test/Transforms/SimplifyCFG/UnreachableEliminate.ll
@@ -238,7 +238,7 @@ else:
}
declare ptr @fn_nonnull_noundef_arg(ptr nonnull noundef %p)
-declare ptr @fn_nonnull_deref_arg(ptr nonnull dereferenceable(4) %p)
+declare ptr @fn_deref_arg(ptr dereferenceable(4) %p)
declare ptr @fn_nonnull_deref_or_null_arg(ptr nonnull dereferenceable_or_null(4) %p)
declare ptr @fn_nonnull_arg(ptr nonnull %p)
declare ptr @fn_noundef_arg(ptr noundef %p)
@@ -271,7 +271,7 @@ define void @test9_deref(i1 %X, ptr %Y) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[X:%.*]], true
; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
-; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_nonnull_deref_arg(ptr [[Y:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_deref_arg(ptr [[Y:%.*]])
; CHECK-NEXT: ret void
;
entry:
@@ -282,7 +282,7 @@ if:
else:
%phi = phi ptr [ %Y, %entry ], [ null, %if ]
- call ptr @fn_nonnull_deref_arg(ptr %phi)
+ call ptr @fn_deref_arg(ptr %phi)
ret void
}
@@ -290,9 +290,8 @@ else:
define void @test9_deref_or_null(i1 %X, ptr %Y) {
; CHECK-LABEL: @test9_deref_or_null(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = xor i1 [[X:%.*]], true
-; CHECK-NEXT: call void @llvm.assume(i1 [[TMP0]])
-; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_nonnull_deref_or_null_arg(ptr [[Y:%.*]])
+; CHECK-NEXT: [[Y:%.*]] = select i1 [[X:%.*]], ptr null, ptr [[Y1:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call ptr @fn_nonnull_deref_or_null_arg(ptr [[Y]])
; CHECK-NEXT: ret void
;
entry:
|
244bf1f
to
5b59246
Compare
Ping. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this adds some measurable overhead: https://llvm-compile-time-tracker.com/compare.php?from=2a3afa2feb90844ad0f8b0bc57663e2aec06cd0a&to=5b59246712a5c9446a526503818b7f86b824f03c&stat=instructions:u
Can you try changing the check in paramHasNonNullAttr to only use paramHasAttr instead of getParamDereferenceableBytes? This should be a more efficient way to check for non-zero dereferenceable bytes.
This is causing crashes on bootstrap builds with RVV enabled, with the assertion at the top of paramHasNonNullAttr failing To reproduce: cat - <<'EOF' > crash.ll
$ cat reduced.ll
target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128"
target triple = "riscv64-unknown-linux-gnu"
define ptr @_ZN4llvm14RegionInfoBaseINS_12RegionTraitsINS_8FunctionEEEEaSEOS4_(ptr %this) {
entry:
br label %for.body.i.i
for.cond.cleanup.i.i: ; preds = %for.body.i.i
ret ptr null
for.body.i.i: ; preds = %for.body.i.i, %entry
%P.024.i.i = phi ptr [ %incdec.ptr.i.i, %for.body.i.i ], [ null, %entry ]
store ptr null, ptr %P.024.i.i, align 8
%incdec.ptr.i.i = getelementptr i8, ptr %P.024.i.i, i64 16
%cmp15.not.i.i = icmp eq ptr %P.024.i.i, %this
br i1 %cmp15.not.i.i, label %for.cond.cleanup.i.i, label %for.body.i.i
}
EOF
/path/to/your/clang -O3 --target=riscv64-linux-gnu -march=rva23u64 crash.ll -fno-crash-diagnostics -mllvm -print-after-all -mllvm -debug The last IR dump before the crash: ; Function Attrs: nofree norecurse nosync nounwind memory(write, inaccessiblemem: none)
define noalias noundef ptr @_ZN4llvm14RegionInfoBaseINS_12RegionTraitsINS_8FunctionEEEEaSEOS4_(ptr readnone captures(address) %this) local_unnamed_addr #0 {
entry:
%this1 = ptrtoint ptr %this to i64
%0 = lshr i64 %this1, 4
%1 = add nuw nsw i64 %0, 1
%2 = call i64 @llvm.vscale.i64()
%3 = shl i64 %2, 1
%4 = call i64 @llvm.umax.i64(i64 %3, i64 12)
%min.iters.check = icmp ult i64 %1, %4
br i1 %min.iters.check, label %scalar.ph, label %vector.scevcheck
vector.scevcheck: ; preds = %entry
%5 = and i64 %this1, 15
%ident.check.not = icmp eq i64 %5, 0
br i1 %ident.check.not, label %vector.ph, label %scalar.ph
vector.ph: ; preds = %vector.scevcheck
%6 = call i64 @llvm.vscale.i64()
%7 = shl i64 %6, 1
%n.mod.vf = urem i64 %1, %7
%n.vec = sub nsw i64 %1, %n.mod.vf
%8 = call i64 @llvm.vscale.i64()
%9 = shl i64 %8, 1
%10 = shl i64 %n.vec, 4
%11 = getelementptr i8, ptr null, i64 %10
br label %vector.body
vector.body: ; preds = %vector.body, %vector.ph
%index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
%pointer.phi = phi ptr [ null, %vector.ph ], [ %ptr.ind, %vector.body ]
%12 = shl i64 %8, 5
%13 = call <vscale x 2 x i64> @llvm.stepvector.nxv2i64()
%14 = shl <vscale x 2 x i64> %13, splat (i64 4)
%vector.gep = getelementptr i8, ptr %pointer.phi, <vscale x 2 x i64> %14
call void @llvm.masked.scatter.nxv2p0.nxv2p0(<vscale x 2 x ptr> zeroinitializer, <vscale x 2 x ptr> %vector.gep, i32 8, <vscale x 2 x i1> splat (i1 true))
%index.next = add nuw i64 %index, %9
%ptr.ind = getelementptr i8, ptr %pointer.phi, i64 %12
%15 = icmp eq i64 %index.next, %n.vec
br i1 %15, label %middle.block, label %vector.body, !llvm.loop !0
middle.block: ; preds = %vector.body
%cmp.n = icmp eq i64 %n.mod.vf, 0
br i1 %cmp.n, label %for.cond.cleanup.i.i, label %scalar.ph
scalar.ph: ; preds = %vector.scevcheck, %entry, %middle.block
%bc.resume.val = phi ptr [ %11, %middle.block ], [ null, %entry ], [ null, %vector.scevcheck ]
br label %for.body.i.i
for.cond.cleanup.i.i.loopexit: ; preds = %for.body.i.i
br label %for.cond.cleanup.i.i
for.cond.cleanup.i.i: ; preds = %for.cond.cleanup.i.i.loopexit, %middle.block
ret ptr null
for.body.i.i: ; preds = %scalar.ph, %for.body.i.i
%P.024.i.i = phi ptr [ %incdec.ptr.i.i, %for.body.i.i ], [ %bc.resume.val, %scalar.ph ]
store ptr null, ptr %P.024.i.i, align 8
%incdec.ptr.i.i = getelementptr i8, ptr %P.024.i.i, i64 16
%cmp15.not.i.i = icmp eq ptr %P.024.i.i, %this
br i1 %cmp15.not.i.i, label %for.cond.cleanup.i.i.loopexit, label %for.body.i.i, !llvm.loop !3
} |
I will have a look. |
No description provided.