Skip to content

Commit 56a0334

Browse files
authored
[Attributor] Keep track of reached returns in AAPointerInfo (#107479)
Instead of visiting call sites in Attribute::checkForAllUses, we now keep track of returns in AAPointerInfo and use the call site return information as required. This way, the user of AAPointerInfo(CallSite)Argument can determine if the call return should be visited. We do not collect them as "may accesses" in the AAPointerInfo(CallSite)Argument itself in case a return user is found.
1 parent 2bcab9b commit 56a0334

File tree

5 files changed

+57
-42
lines changed

5 files changed

+57
-42
lines changed

llvm/include/llvm/Transforms/IPO/Attributor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6119,6 +6119,7 @@ struct AAPointerInfo : public AbstractAttribute {
61196119
virtual const_bin_iterator begin() const = 0;
61206120
virtual const_bin_iterator end() const = 0;
61216121
virtual int64_t numOffsetBins() const = 0;
6122+
virtual bool reachesReturn() const = 0;
61226123

61236124
/// Call \p CB on all accesses that might interfere with \p Range and return
61246125
/// true if all such accesses were known and the callback returned true for

llvm/lib/Transforms/IPO/Attributor.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,22 +1852,6 @@ bool Attributor::checkForAllUses(
18521852

18531853
User &Usr = *U->getUser();
18541854
AddUsers(Usr, /* OldUse */ nullptr);
1855-
1856-
auto *RI = dyn_cast<ReturnInst>(&Usr);
1857-
if (!RI)
1858-
continue;
1859-
1860-
Function &F = *RI->getFunction();
1861-
auto CallSitePred = [&](AbstractCallSite ACS) {
1862-
return AddUsers(*ACS.getInstruction(), U);
1863-
};
1864-
if (!checkForAllCallSites(CallSitePred, F, /* RequireAllCallSites */ true,
1865-
&QueryingAA, UsedAssumedInformation)) {
1866-
LLVM_DEBUG(dbgs() << "[Attributor] Could not follow return instruction "
1867-
"to all call sites: "
1868-
<< *RI << "\n");
1869-
return false;
1870-
}
18711855
}
18721856

18731857
return true;

llvm/lib/Transforms/IPO/AttributorAttributes.cpp

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ struct AA::PointerInfo::State : public AbstractState {
827827
AccessList = R.AccessList;
828828
OffsetBins = R.OffsetBins;
829829
RemoteIMap = R.RemoteIMap;
830+
ReachesReturn = R.ReachesReturn;
830831
return *this;
831832
}
832833

@@ -837,6 +838,7 @@ struct AA::PointerInfo::State : public AbstractState {
837838
std::swap(AccessList, R.AccessList);
838839
std::swap(OffsetBins, R.OffsetBins);
839840
std::swap(RemoteIMap, R.RemoteIMap);
841+
std::swap(ReachesReturn, R.ReachesReturn);
840842
return *this;
841843
}
842844

@@ -878,11 +880,16 @@ struct AA::PointerInfo::State : public AbstractState {
878880
AAPointerInfo::OffsetBinsTy OffsetBins;
879881
DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;
880882

883+
/// Flag to determine if the underlying pointer is reaching a return statement
884+
/// in the associated function or not. Returns in other functions cause
885+
/// invalidation.
886+
bool ReachesReturn = false;
887+
881888
/// See AAPointerInfo::forallInterferingAccesses.
882889
bool forallInterferingAccesses(
883890
AA::RangeTy Range,
884891
function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const {
885-
if (!isValidState())
892+
if (!isValidState() || ReachesReturn)
886893
return false;
887894

888895
for (const auto &It : OffsetBins) {
@@ -904,7 +911,7 @@ struct AA::PointerInfo::State : public AbstractState {
904911
Instruction &I,
905912
function_ref<bool(const AAPointerInfo::Access &, bool)> CB,
906913
AA::RangeTy &Range) const {
907-
if (!isValidState())
914+
if (!isValidState() || ReachesReturn)
908915
return false;
909916

910917
auto LocalList = RemoteIMap.find(&I);
@@ -1071,7 +1078,8 @@ struct AAPointerInfoImpl
10711078
return std::string("PointerInfo ") +
10721079
(isValidState() ? (std::string("#") +
10731080
std::to_string(OffsetBins.size()) + " bins")
1074-
: "<invalid>");
1081+
: "<invalid>") +
1082+
(ReachesReturn ? " (returned)" : "");
10751083
}
10761084

10771085
/// See AbstractAttribute::manifest(...).
@@ -1084,6 +1092,7 @@ struct AAPointerInfoImpl
10841092
virtual int64_t numOffsetBins() const override {
10851093
return State::numOffsetBins();
10861094
}
1095+
virtual bool reachesReturn() const override { return ReachesReturn; }
10871096

10881097
bool forallInterferingAccesses(
10891098
AA::RangeTy Range,
@@ -1373,6 +1382,7 @@ struct AAPointerInfoImpl
13731382

13741383
const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA);
13751384
bool IsByval = OtherAAImpl.getAssociatedArgument()->hasByValAttr();
1385+
ReachesReturn = OtherAAImpl.ReachesReturn;
13761386

13771387
// Combine the accesses bin by bin.
13781388
ChangeStatus Changed = ChangeStatus::UNCHANGED;
@@ -1666,8 +1676,13 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
16661676
}
16671677
if (isa<PtrToIntInst>(Usr))
16681678
return false;
1669-
if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr))
1679+
if (isa<CastInst>(Usr) || isa<SelectInst>(Usr))
16701680
return HandlePassthroughUser(Usr, CurPtr, Follow);
1681+
// Returns are allowed if they are in the associated functions. Users can
1682+
// then check the call site return. Returns from other functions can't be
1683+
// tracked and are cause for invalidation.
1684+
if (auto *RI = dyn_cast<ReturnInst>(Usr))
1685+
return ReachesReturn = RI->getFunction() == getAssociatedFunction();
16711686

16721687
// For PHIs we need to take care of the recurrence explicitly as the value
16731688
// might change while we iterate through a loop. For now, we give up if
@@ -1898,15 +1913,37 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
18981913
DepClassTy::REQUIRED);
18991914
if (!CSArgPI)
19001915
return false;
1901-
bool IsMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
1916+
bool IsArgMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
19021917
Changed = translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB,
1903-
IsMustAcc) |
1918+
IsArgMustAcc) |
1919+
Changed;
1920+
if (!CSArgPI->reachesReturn())
1921+
return isValidState();
1922+
1923+
Function *Callee = CB->getCalledFunction();
1924+
if (!Callee || Callee->arg_size() <= ArgNo)
1925+
return false;
1926+
bool UsedAssumedInformation = false;
1927+
auto ReturnedValue = A.getAssumedSimplified(
1928+
IRPosition::returned(*Callee), *this, UsedAssumedInformation,
1929+
AA::ValueScope::Intraprocedural);
1930+
auto *ReturnedArg =
1931+
dyn_cast_or_null<Argument>(ReturnedValue.value_or(nullptr));
1932+
auto *Arg = Callee->getArg(ArgNo);
1933+
if (ReturnedArg && Arg != ReturnedArg)
1934+
return true;
1935+
bool IsRetMustAcc = IsArgMustAcc && (ReturnedArg == Arg);
1936+
const auto *CSRetPI = A.getAAFor<AAPointerInfo>(
1937+
*this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
1938+
if (!CSRetPI)
1939+
return false;
1940+
Changed = translateAndAddState(A, *CSRetPI, OffsetInfoMap[CurPtr], *CB,
1941+
IsRetMustAcc) |
19041942
Changed;
19051943
return isValidState();
19061944
}
19071945
LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB
19081946
<< "\n");
1909-
// TODO: Allow some call uses
19101947
return false;
19111948
}
19121949

@@ -2342,8 +2379,10 @@ struct AANoFreeFloating : AANoFreeImpl {
23422379
Follow = true;
23432380
return true;
23442381
}
2345-
if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI) ||
2346-
isa<ReturnInst>(UserI))
2382+
if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI))
2383+
return true;
2384+
2385+
if (isa<ReturnInst>(UserI) && getIRPosition().isArgumentPosition())
23472386
return true;
23482387

23492388
// Unknown user.
@@ -12740,7 +12779,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo {
1274012779
if (!PI)
1274112780
return indicatePessimisticFixpoint();
1274212781

12743-
if (!PI->getState().isValidState())
12782+
if (!PI->getState().isValidState() || PI->reachesReturn())
1274412783
return indicatePessimisticFixpoint();
1274512784

1274612785
const DataLayout &DL = A.getDataLayout();

llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
3434
define dso_local i32 @main() {
3535
; TUNIT-LABEL: define {{[^@]+}}@main() {
3636
; TUNIT-NEXT: entry:
37-
; TUNIT-NEXT: [[ALLOC11:%.*]] = alloca i8, i32 0, align 8
38-
; TUNIT-NEXT: [[ALLOC22:%.*]] = alloca i8, i32 0, align 8
37+
; TUNIT-NEXT: [[ALLOC1:%.*]] = alloca i8, align 8
38+
; TUNIT-NEXT: [[ALLOC2:%.*]] = alloca i8, align 8
3939
; TUNIT-NEXT: [[THREAD:%.*]] = alloca i64, align 8
4040
; TUNIT-NEXT: [[CALL:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @foo, ptr nofree readnone align 4294967296 undef)
4141
; TUNIT-NEXT: [[CALL1:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @bar, ptr noalias nocapture nofree nonnull readnone align 8 dereferenceable(8) undef)
42-
; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC11]])
43-
; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC22]])
42+
; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC1]])
43+
; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC2]])
4444
; TUNIT-NEXT: ret i32 0
4545
;
4646
; CGSCC-LABEL: define {{[^@]+}}@main() {

llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3185,10 +3185,7 @@ define i32 @may_access_after_return(i32 noundef %N, i32 noundef %M) {
31853185
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
31863186
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
31873187
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
3188-
; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
3189-
; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
3190-
; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
3191-
; TUNIT-NEXT: ret i32 [[ADD]]
3188+
; TUNIT-NEXT: ret i32 8
31923189
;
31933190
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
31943191
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return
@@ -3304,10 +3301,7 @@ define i32 @may_access_after_return_no_choice1(i32 noundef %N, i32 noundef %M) {
33043301
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
33053302
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
33063303
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
3307-
; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
3308-
; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
3309-
; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
3310-
; TUNIT-NEXT: ret i32 [[ADD]]
3304+
; TUNIT-NEXT: ret i32 8
33113305
;
33123306
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
33133307
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice1
@@ -3342,10 +3336,7 @@ define i32 @may_access_after_return_no_choice2(i32 noundef %N, i32 noundef %M) {
33423336
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
33433337
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
33443338
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]]) #[[ATTR18]]
3345-
; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
3346-
; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
3347-
; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
3348-
; TUNIT-NEXT: ret i32 [[ADD]]
3339+
; TUNIT-NEXT: ret i32 8
33493340
;
33503341
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
33513342
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice2

0 commit comments

Comments
 (0)