-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[Attributor] Keep track of reached returns in AAPointerInfo #107479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Johannes Doerfert (jdoerfert) ChangesInstead of visiting call sites in Attribute::checkForAllUses, we now Full diff: https://github.com/llvm/llvm-project/pull/107479.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 5844fb8b0f8938..6ab63ba582c546 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -6119,6 +6119,7 @@ struct AAPointerInfo : public AbstractAttribute {
virtual const_bin_iterator begin() const = 0;
virtual const_bin_iterator end() const = 0;
virtual int64_t numOffsetBins() const = 0;
+ virtual bool reachesReturn() const = 0;
/// Call \p CB on all accesses that might interfere with \p Range and return
/// true if all such accesses were known and the callback returned true for
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 38b61b6a88357c..56d1133b25549a 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -1852,22 +1852,6 @@ bool Attributor::checkForAllUses(
User &Usr = *U->getUser();
AddUsers(Usr, /* OldUse */ nullptr);
-
- auto *RI = dyn_cast<ReturnInst>(&Usr);
- if (!RI)
- continue;
-
- Function &F = *RI->getFunction();
- auto CallSitePred = [&](AbstractCallSite ACS) {
- return AddUsers(*ACS.getInstruction(), U);
- };
- if (!checkForAllCallSites(CallSitePred, F, /* RequireAllCallSites */ true,
- &QueryingAA, UsedAssumedInformation)) {
- LLVM_DEBUG(dbgs() << "[Attributor] Could not follow return instruction "
- "to all call sites: "
- << *RI << "\n");
- return false;
- }
}
return true;
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 69d29b6c042349..36e7049b18ae3f 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -827,6 +827,7 @@ struct AA::PointerInfo::State : public AbstractState {
AccessList = R.AccessList;
OffsetBins = R.OffsetBins;
RemoteIMap = R.RemoteIMap;
+ ReachesReturn = R.ReachesReturn;
return *this;
}
@@ -837,6 +838,7 @@ struct AA::PointerInfo::State : public AbstractState {
std::swap(AccessList, R.AccessList);
std::swap(OffsetBins, R.OffsetBins);
std::swap(RemoteIMap, R.RemoteIMap);
+ std::swap(ReachesReturn, R.ReachesReturn);
return *this;
}
@@ -878,11 +880,16 @@ struct AA::PointerInfo::State : public AbstractState {
AAPointerInfo::OffsetBinsTy OffsetBins;
DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;
+ /// Flag to determine if the underlying pointer is reaching a return statement
+ /// in the associated function or not. Returns in other functions cause
+ /// invalidation.
+ bool ReachesReturn = false;
+
/// See AAPointerInfo::forallInterferingAccesses.
bool forallInterferingAccesses(
AA::RangeTy Range,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const {
- if (!isValidState())
+ if (!isValidState() || ReachesReturn)
return false;
for (const auto &It : OffsetBins) {
@@ -904,7 +911,7 @@ struct AA::PointerInfo::State : public AbstractState {
Instruction &I,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB,
AA::RangeTy &Range) const {
- if (!isValidState())
+ if (!isValidState() || ReachesReturn)
return false;
auto LocalList = RemoteIMap.find(&I);
@@ -1071,7 +1078,8 @@ struct AAPointerInfoImpl
return std::string("PointerInfo ") +
(isValidState() ? (std::string("#") +
std::to_string(OffsetBins.size()) + " bins")
- : "<invalid>");
+ : "<invalid>") +
+ (ReachesReturn ? " (returned)" : "");
}
/// See AbstractAttribute::manifest(...).
@@ -1084,6 +1092,7 @@ struct AAPointerInfoImpl
virtual int64_t numOffsetBins() const override {
return State::numOffsetBins();
}
+ virtual bool reachesReturn() const override { return ReachesReturn; }
bool forallInterferingAccesses(
AA::RangeTy Range,
@@ -1373,6 +1382,7 @@ struct AAPointerInfoImpl
const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA);
bool IsByval = OtherAAImpl.getAssociatedArgument()->hasByValAttr();
+ ReachesReturn = OtherAAImpl.ReachesReturn;
// Combine the accesses bin by bin.
ChangeStatus Changed = ChangeStatus::UNCHANGED;
@@ -1397,7 +1407,8 @@ struct AAPointerInfoImpl
}
ChangeStatus translateAndAddState(Attributor &A, const AAPointerInfo &OtherAA,
- const OffsetInfo &Offsets, CallBase &CB) {
+ const OffsetInfo &Offsets, CallBase &CB,
+ bool IsMustAcc) {
using namespace AA::PointerInfo;
if (!OtherAA.getState().isValidState() || !isValidState())
return indicatePessimisticFixpoint();
@@ -1410,6 +1421,8 @@ struct AAPointerInfoImpl
for (const auto &It : State) {
for (auto Index : It.getSecond()) {
const auto &RAcc = State.getAccess(Index);
+ if (!IsMustAcc && RAcc.isAssumption())
+ continue;
for (auto Offset : Offsets) {
auto NewRanges = Offset == AA::RangeTy::Unknown
? AA::RangeTy::getUnknown()
@@ -1417,9 +1430,11 @@ struct AAPointerInfoImpl
if (!NewRanges.isUnknown()) {
NewRanges.addToAllOffsets(Offset);
}
- Changed |=
- addAccess(A, NewRanges, CB, RAcc.getContent(), RAcc.getKind(),
- RAcc.getType(), RAcc.getRemoteInst());
+ AccessKind AK = RAcc.getKind();
+ if (!IsMustAcc)
+ AK = AccessKind((AK & ~AK_MUST) | AK_MAY);
+ Changed |= addAccess(A, NewRanges, CB, RAcc.getContent(), AK,
+ RAcc.getType(), RAcc.getRemoteInst());
}
}
}
@@ -1661,8 +1676,13 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
}
if (isa<PtrToIntInst>(Usr))
return false;
- if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr))
+ if (isa<CastInst>(Usr) || isa<SelectInst>(Usr))
return HandlePassthroughUser(Usr, CurPtr, Follow);
+ // Returns are allowed if they are in the associated functions. Users can
+ // then check the call site return. Returns from other functions can't be
+ // tracked and are cause for invalidation.
+ if (auto *RI = dyn_cast<ReturnInst>(Usr))
+ return ReachesReturn = RI->getFunction() == getAssociatedFunction();
// For PHIs we need to take care of the recurrence explicitly as the value
// might change while we iterate through a loop. For now, we give up if
@@ -1893,14 +1913,37 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
DepClassTy::REQUIRED);
if (!CSArgPI)
return false;
- Changed =
- translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB) |
- Changed;
+ bool IsArgMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
+ Changed = translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB,
+ IsArgMustAcc) |
+ Changed;
+ if (!CSArgPI->reachesReturn())
+ return isValidState();
+
+ Function *Callee = CB->getCalledFunction();
+ if (!Callee || Callee->arg_size() <= ArgNo)
+ return false;
+ bool UsedAssumedInformation = false;
+ auto ReturnedValue = A.getAssumedSimplified(
+ IRPosition::returned(*Callee), *this, UsedAssumedInformation,
+ AA::ValueScope::Intraprocedural);
+ auto *ReturnedArg =
+ dyn_cast_or_null<Argument>(ReturnedValue.value_or(nullptr));
+ auto *Arg = Callee->getArg(ArgNo);
+ if (ReturnedArg && Arg != ReturnedArg)
+ return true;
+ bool IsRetMustAcc = IsArgMustAcc && (ReturnedArg == Arg);
+ const auto *CSRetPI = A.getAAFor<AAPointerInfo>(
+ *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
+ if (!CSRetPI)
+ return false;
+ Changed = translateAndAddState(A, *CSRetPI, OffsetInfoMap[CurPtr], *CB,
+ IsRetMustAcc) |
+ Changed;
return isValidState();
}
LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB
<< "\n");
- // TODO: Allow some call uses
return false;
}
@@ -2336,8 +2379,10 @@ struct AANoFreeFloating : AANoFreeImpl {
Follow = true;
return true;
}
- if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI) ||
- isa<ReturnInst>(UserI))
+ if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI))
+ return true;
+
+ if (isa<ReturnInst>(UserI) && getIRPosition().isArgumentPosition())
return true;
// Unknown user.
@@ -12734,7 +12779,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo {
if (!PI)
return indicatePessimisticFixpoint();
- if (!PI->getState().isValidState())
+ if (!PI->getState().isValidState() || PI->reachesReturn())
return indicatePessimisticFixpoint();
const DataLayout &DL = A.getDataLayout();
diff --git a/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll b/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
index 490894d1290231..01a97821140ec6 100644
--- a/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
+++ b/llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
@@ -34,13 +34,13 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
define dso_local i32 @main() {
; TUNIT-LABEL: define {{[^@]+}}@main() {
; TUNIT-NEXT: entry:
-; TUNIT-NEXT: [[ALLOC11:%.*]] = alloca i8, i32 0, align 8
-; TUNIT-NEXT: [[ALLOC22:%.*]] = alloca i8, i32 0, align 8
+; TUNIT-NEXT: [[ALLOC1:%.*]] = alloca i8, align 8
+; TUNIT-NEXT: [[ALLOC2:%.*]] = alloca i8, align 8
; TUNIT-NEXT: [[THREAD:%.*]] = alloca i64, align 8
; TUNIT-NEXT: [[CALL:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @foo, ptr nofree readnone align 4294967296 undef)
; TUNIT-NEXT: [[CALL1:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @bar, ptr noalias nocapture nofree nonnull readnone align 8 dereferenceable(8) undef)
-; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC11]])
-; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC22]])
+; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC1]])
+; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC2]])
; TUNIT-NEXT: ret i32 0
;
; CGSCC-LABEL: define {{[^@]+}}@main() {
diff --git a/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll b/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
index f7f92e3c87a629..378560cc89cd12 100644
--- a/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
+++ b/llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
@@ -3176,7 +3176,7 @@ define internal i32 @recSimplify2() {
ret i32 %r
}
-; TODO: Verify we do not return 10.
+; Verify we do not return 10.
define i32 @may_access_after_return(i32 noundef %N, i32 noundef %M) {
; TUNIT: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
; TUNIT-LABEL: define {{[^@]+}}@may_access_after_return
@@ -3185,7 +3185,7 @@ define i32 @may_access_after_return(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
-; TUNIT-NEXT: ret i32 10
+; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return
@@ -3237,7 +3237,7 @@ entry:
ret ptr %P
}
-; TODO: Verify we do not return 10.
+; Verify we do not return 10.
define i32 @may_access_after_return_choice(i32 noundef %N, i32 noundef %M, i1 %c) {
; TUNIT: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
; TUNIT-LABEL: define {{[^@]+}}@may_access_after_return_choice
@@ -3248,7 +3248,10 @@ define i32 @may_access_after_return_choice(i32 noundef %N, i32 noundef %M, i1 %c
; TUNIT-NEXT: [[CALL:%.*]] = call nonnull align 4 dereferenceable(4) ptr @passthrough_choice(i1 [[C]], ptr noalias nofree noundef nonnull readnone align 4 dereferenceable(4) "no-capture-maybe-returned" [[A]], ptr noalias nofree noundef nonnull readnone align 4 dereferenceable(4) "no-capture-maybe-returned" [[B]]) #[[ATTR23:[0-9]+]]
; TUNIT-NEXT: [[CALL1:%.*]] = call nonnull align 4 dereferenceable(4) ptr @passthrough_choice(i1 [[C]], ptr noalias nofree noundef nonnull readnone align 4 dereferenceable(4) "no-capture-maybe-returned" [[B]], ptr noalias nofree noundef nonnull readnone align 4 dereferenceable(4) "no-capture-maybe-returned" [[A]]) #[[ATTR23]]
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[CALL]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[CALL1]]) #[[ATTR18]]
-; TUNIT-NEXT: ret i32 10
+; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
+; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
+; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
+; TUNIT-NEXT: ret i32 [[ADD]]
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_choice
@@ -3289,7 +3292,7 @@ entry:
ret ptr %R
}
-; TODO: Verify we do not return 10.
+; Verify we do not return 10.
define i32 @may_access_after_return_no_choice1(i32 noundef %N, i32 noundef %M) {
; TUNIT: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
; TUNIT-LABEL: define {{[^@]+}}@may_access_after_return_no_choice1
@@ -3298,7 +3301,7 @@ define i32 @may_access_after_return_no_choice1(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
-; TUNIT-NEXT: ret i32 10
+; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice1
@@ -3324,7 +3327,7 @@ entry:
ret i32 %add
}
-; TODO: Verify we do not return 10.
+; Verify we do not return 10.
define i32 @may_access_after_return_no_choice2(i32 noundef %N, i32 noundef %M) {
; TUNIT: Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
; TUNIT-LABEL: define {{[^@]+}}@may_access_after_return_no_choice2
@@ -3333,7 +3336,7 @@ define i32 @may_access_after_return_no_choice2(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]]) #[[ATTR18]]
-; TUNIT-NEXT: ret i32 10
+; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice2
|
Instead of visiting call sites in Attribute::checkForAllUses, we now keep track of returns in AAPointerInfo and use the call site return information as required. This way, the user of AAPointerInfo(CallSite)Argument can determine if the call return should be visited. We do not collect them as "may accesses" in the AAPointerInfo(CallSite)Argument itself in case a return user is found.
bb7251d
to
a211d88
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks reasonable
…lvm#107479)" makes 532.exa_sph 20x slower This reverts commit 56a0334. Change-Id: I796a33f51ce312ca550a5c48c0391de05be302a7
) Instead of visiting call sites in Attribute::checkForAllUses, we now keep track of returns in AAPointerInfo and use the call site return information as required. This way, the user of AAPointerInfo(CallSite)Argument can determine if the call return should be visited. We do not collect them as "may accesses" in the AAPointerInfo(CallSite)Argument itself in case a return user is found. Change-Id: Icb636ff4a4cf2bb53dd04fce20e17f53c2f5aee1
Instead of visiting call sites in Attribute::checkForAllUses, we now
keep track of returns in AAPointerInfo and use the call site return
information as required. This way, the user of
AAPointerInfo(CallSite)Argument can determine if the call return should
be visited. We do not collect them as "may accesses" in the
AAPointerInfo(CallSite)Argument itself in case a return user is found.