Skip to content

Commit 58dd8a4

Browse files
committed
Apply initializes attribute in DSE
1 parent c811ea4 commit 58dd8a4

File tree

2 files changed

+229
-42
lines changed

2 files changed

+229
-42
lines changed

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp

Lines changed: 184 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "llvm/IR/Argument.h"
5353
#include "llvm/IR/BasicBlock.h"
5454
#include "llvm/IR/Constant.h"
55+
#include "llvm/IR/ConstantRangeList.h"
5556
#include "llvm/IR/Constants.h"
5657
#include "llvm/IR/DataLayout.h"
5758
#include "llvm/IR/DebugInfo.h"
@@ -164,6 +165,10 @@ static cl::opt<bool>
164165
OptimizeMemorySSA("dse-optimize-memoryssa", cl::init(true), cl::Hidden,
165166
cl::desc("Allow DSE to optimize memory accesses."));
166167

168+
static cl::opt<bool> EnableInitializesImprovement(
169+
"enable-dse-initializes-attr-improvement", cl::init(false), cl::Hidden,
170+
cl::desc("Enable the initializes attr improvement in DSE"));
171+
167172
//===----------------------------------------------------------------------===//
168173
// Helper functions
169174
//===----------------------------------------------------------------------===//
@@ -809,8 +814,10 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
809814
// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
810815
// defined by `MemDef`.
811816
struct MemoryLocationWrapper {
812-
MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
813-
: MemLoc(MemLoc), MemDef(MemDef) {
817+
MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef,
818+
bool DefByInitializesAttr)
819+
: MemLoc(MemLoc), MemDef(MemDef),
820+
DefByInitializesAttr(DefByInitializesAttr) {
814821
assert(MemLoc.Ptr && "MemLoc should be not null");
815822
UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
816823
DefInst = MemDef->getMemoryInst();
@@ -820,20 +827,121 @@ struct MemoryLocationWrapper {
820827
const Value *UnderlyingObject;
821828
MemoryDef *MemDef;
822829
Instruction *DefInst;
830+
bool DefByInitializesAttr = false;
823831
};
824832

825833
// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
826834
// defined by this MemoryDef.
827835
struct MemoryDefWrapper {
828-
MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
836+
MemoryDefWrapper(
837+
MemoryDef *MemDef,
838+
const SmallVectorImpl<std::pair<MemoryLocation, bool>> &MemLocations) {
829839
DefInst = MemDef->getMemoryInst();
830-
if (MemLoc.has_value())
831-
DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
840+
for (auto &[MemLoc, DefByInitializesAttr] : MemLocations)
841+
DefinedLocations.push_back(
842+
MemoryLocationWrapper(MemLoc, MemDef, DefByInitializesAttr));
832843
}
833844
Instruction *DefInst;
834-
std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
845+
SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
846+
};
847+
848+
bool HasInitializesAttr(Instruction *I) {
849+
CallBase *CB = dyn_cast<CallBase>(I);
850+
if (!CB)
851+
return false;
852+
853+
for (size_t Idx = 0; Idx < CB->arg_size(); Idx++)
854+
if (CB->paramHasAttr(Idx, Attribute::Initializes))
855+
return true;
856+
return false;
857+
}
858+
859+
struct ArgumentInitInfo {
860+
size_t Idx = -1;
861+
ConstantRangeList Inits;
862+
bool HasDeadOnUnwindAttr = false;
863+
bool FuncHasNoUnwindAttr = false;
835864
};
836865

866+
ConstantRangeList
867+
GetMergedInitAttr(const SmallVectorImpl<ArgumentInitInfo> &Args) {
868+
if (Args.empty())
869+
return {};
870+
871+
// To address unwind, the function should have nounwind attribute or the
872+
// arguments have dead_on_unwind attribute. Otherwise, return empty.
873+
for (const auto &Arg : Args) {
874+
if (!Arg.FuncHasNoUnwindAttr && !Arg.HasDeadOnUnwindAttr)
875+
return {};
876+
if (Arg.Inits.empty())
877+
return {};
878+
}
879+
880+
if (Args.size() == 1)
881+
return Args[0].Inits;
882+
883+
ConstantRangeList MergedIntervals = Args[0].Inits;
884+
for (size_t i = 1; i < Args.size(); i++) {
885+
MergedIntervals = MergedIntervals.intersectWith(Args[i].Inits);
886+
}
887+
return MergedIntervals;
888+
}
889+
890+
// Return the locations wrote by the initializes attribute.
891+
// Note that this function considers:
892+
// 1. Unwind edge: apply "initializes" attribute only if the callee has
893+
// "nounwind" attribute or the argument has "dead_on_unwind" attribute.
894+
// 2. Argument alias: for aliasing arguments, the "initializes" attribute is
895+
// the merged range list of their "initializes" attributes.
896+
SmallVector<MemoryLocation, 1>
897+
GetInitializesArgMemLoc(const Instruction *I, BatchAAResults &BatchAA) {
898+
const CallBase *CB = dyn_cast<CallBase>(I);
899+
if (!CB)
900+
return {};
901+
902+
bool HasNoUnwindAttr = CB->hasFnAttr(Attribute::NoUnwind);
903+
SmallMapVector<Value *, SmallVector<ArgumentInitInfo, 2>, 2> Arguments;
904+
for (size_t Idx = 0; Idx < CB->arg_size(); Idx++) {
905+
bool HasDeadOnUnwindAttr = CB->paramHasAttr(Idx, Attribute::DeadOnUnwind);
906+
907+
ConstantRangeList Inits;
908+
if (CB->paramHasAttr(Idx, Attribute::Initializes))
909+
Inits = CB->getParamAttr(Idx, Attribute::Initializes)
910+
.getValueAsConstantRangeList();
911+
912+
ArgumentInitInfo InitInfo{Idx, Inits, HasDeadOnUnwindAttr, HasNoUnwindAttr};
913+
Value *CurArg = CB->getArgOperand(Idx);
914+
bool FoundAliasing = false;
915+
for (auto &[Arg, AliasList] : Arguments) {
916+
if (BatchAA.isMustAlias(Arg, CurArg)) {
917+
FoundAliasing = true;
918+
AliasList.push_back(InitInfo);
919+
}
920+
}
921+
if (!FoundAliasing)
922+
Arguments[CurArg] = {InitInfo};
923+
}
924+
925+
SmallVector<MemoryLocation, 1> Locations;
926+
for (const auto &[_, Args] : Arguments) {
927+
auto MergedInitAttr = GetMergedInitAttr(Args);
928+
if (MergedInitAttr.empty())
929+
continue;
930+
931+
for (const auto &Arg : Args) {
932+
for (const auto &Range : MergedInitAttr) {
933+
int64_t Start = Range.getLower().getSExtValue();
934+
int64_t End = Range.getUpper().getSExtValue();
935+
if (Start == 0)
936+
Locations.push_back(MemoryLocation(CB->getArgOperand(Arg.Idx),
937+
LocationSize::precise(End - Start),
938+
CB->getAAMetadata()));
939+
}
940+
}
941+
}
942+
return Locations;
943+
}
944+
837945
struct DSEState {
838946
Function &F;
839947
AliasAnalysis &AA;
@@ -911,7 +1019,8 @@ struct DSEState {
9111019

9121020
auto *MD = dyn_cast_or_null<MemoryDef>(MA);
9131021
if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
914-
(getLocForWrite(&I) || isMemTerminatorInst(&I)))
1022+
(getLocForWrite(&I) || isMemTerminatorInst(&I) ||
1023+
HasInitializesAttr(&I)))
9151024
MemDefs.push_back(MD);
9161025
}
9171026
}
@@ -1147,13 +1256,24 @@ struct DSEState {
11471256
return MemoryLocation::getOrNone(I);
11481257
}
11491258

1150-
std::optional<MemoryLocation> getLocForInst(Instruction *I) {
1259+
SmallVector<std::pair<MemoryLocation, bool>, 1>
1260+
getLocForInst(Instruction *I, bool consider_initializes_attr) {
1261+
SmallVector<std::pair<MemoryLocation, bool>, 1> Locations;
11511262
if (isMemTerminatorInst(I)) {
1152-
if (auto Loc = getLocForTerminator(I)) {
1153-
return Loc->first;
1263+
if (auto Loc = getLocForTerminator(I))
1264+
Locations.push_back(std::make_pair(Loc->first, false));
1265+
return Locations;
1266+
}
1267+
1268+
if (auto Loc = getLocForWrite(I))
1269+
Locations.push_back(std::make_pair(*Loc, false));
1270+
1271+
if (consider_initializes_attr) {
1272+
for (auto &MemLoc : GetInitializesArgMemLoc(I, BatchAA)) {
1273+
Locations.push_back(std::make_pair(MemLoc, true));
11541274
}
11551275
}
1156-
return getLocForWrite(I);
1276+
return Locations;
11571277
}
11581278

11591279
/// Assuming this instruction has a dead analyzable write, can we delete
@@ -1365,7 +1485,8 @@ struct DSEState {
13651485
getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
13661486
const MemoryLocation &KillingLoc, const Value *KillingUndObj,
13671487
unsigned &ScanLimit, unsigned &WalkerStepLimit,
1368-
bool IsMemTerm, unsigned &PartialLimit) {
1488+
bool IsMemTerm, unsigned &PartialLimit,
1489+
bool IsInitializesAttrMemLoc) {
13691490
if (ScanLimit == 0 || WalkerStepLimit == 0) {
13701491
LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n");
13711492
return std::nullopt;
@@ -1602,7 +1723,19 @@ struct DSEState {
16021723

16031724
// Uses which may read the original MemoryDef mean we cannot eliminate the
16041725
// original MD. Stop walk.
1605-
if (isReadClobber(MaybeDeadLoc, UseInst)) {
1726+
// If KillingDef is a CallInst with "initializes" attribute, the reads in
1727+
// Callee would be dominated by initializations, so this should be safe.
1728+
bool IsKillingDefFromInitAttr = false;
1729+
if (IsInitializesAttrMemLoc) {
1730+
if (KillingI == UseInst &&
1731+
KillingUndObj == getUnderlyingObject(MaybeDeadLoc.Ptr)) {
1732+
IsKillingDefFromInitAttr = true;
1733+
// Note that, we don't need to check aliasing arguments here since
1734+
// aliasing has been considered at the begining.
1735+
}
1736+
}
1737+
1738+
if (isReadClobber(MaybeDeadLoc, UseInst) && !IsKillingDefFromInitAttr) {
16061739
LLVM_DEBUG(dbgs() << " ... found read clobber\n");
16071740
return std::nullopt;
16081741
}
@@ -2207,7 +2340,8 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
22072340
std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
22082341
KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
22092342
KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
2210-
isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
2343+
isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit,
2344+
KillingLocWrapper.DefByInitializesAttr);
22112345

22122346
if (!MaybeDeadAccess) {
22132347
LLVM_DEBUG(dbgs() << " finished walk\n");
@@ -2232,8 +2366,11 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
22322366
}
22332367
MemoryDefWrapper DeadDefWrapper(
22342368
cast<MemoryDef>(DeadAccess),
2235-
getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
2236-
MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
2369+
getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst(),
2370+
/*consider_initializes_attr=*/false));
2371+
assert(DeadDefWrapper.DefinedLocations.size() == 1);
2372+
MemoryLocationWrapper &DeadLocWrapper =
2373+
DeadDefWrapper.DefinedLocations.front();
22372374
LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
22382375
ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
22392376
NumGetDomMemoryDefPassed++;
@@ -2311,37 +2448,41 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
23112448
}
23122449

23132450
bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) {
2314-
if (!KillingDefWrapper.DefinedLocation.has_value()) {
2451+
if (KillingDefWrapper.DefinedLocations.empty()) {
23152452
LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
23162453
<< *KillingDefWrapper.DefInst << "\n");
23172454
return false;
23182455
}
23192456

2320-
auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
2321-
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
2322-
<< *KillingLocWrapper.MemDef << " ("
2323-
<< *KillingLocWrapper.DefInst << ")\n");
2324-
auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);
2325-
2326-
// Check if the store is a no-op.
2327-
if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
2328-
KillingLocWrapper.UnderlyingObject)) {
2329-
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
2330-
<< *KillingLocWrapper.DefInst << '\n');
2331-
deleteDeadInstruction(KillingLocWrapper.DefInst);
2332-
NumRedundantStores++;
2333-
return true;
2334-
}
2335-
// Can we form a calloc from a memset/malloc pair?
2336-
if (!DeletedKillingLoc &&
2337-
tryFoldIntoCalloc(KillingLocWrapper.MemDef,
2338-
KillingLocWrapper.UnderlyingObject)) {
2339-
LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
2340-
<< " DEAD: " << *KillingLocWrapper.DefInst << '\n');
2341-
deleteDeadInstruction(KillingLocWrapper.DefInst);
2342-
return true;
2457+
bool MadeChange = false;
2458+
for (auto &KillingLocWrapper : KillingDefWrapper.DefinedLocations) {
2459+
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
2460+
<< *KillingLocWrapper.MemDef << " ("
2461+
<< *KillingLocWrapper.DefInst << ")\n");
2462+
auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);
2463+
2464+
// Check if the store is a no-op.
2465+
if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
2466+
KillingLocWrapper.UnderlyingObject)) {
2467+
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
2468+
<< *KillingLocWrapper.DefInst << '\n');
2469+
deleteDeadInstruction(KillingLocWrapper.DefInst);
2470+
NumRedundantStores++;
2471+
MadeChange = true;
2472+
continue;
2473+
}
2474+
// Can we form a calloc from a memset/malloc pair?
2475+
if (!DeletedKillingLoc &&
2476+
tryFoldIntoCalloc(KillingLocWrapper.MemDef,
2477+
KillingLocWrapper.UnderlyingObject)) {
2478+
LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
2479+
<< " DEAD: " << *KillingLocWrapper.DefInst << '\n');
2480+
deleteDeadInstruction(KillingLocWrapper.DefInst);
2481+
MadeChange = true;
2482+
continue;
2483+
}
23432484
}
2344-
return Changed;
2485+
return MadeChange;
23452486
}
23462487

23472488
static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
@@ -2357,7 +2498,8 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
23572498
continue;
23582499

23592500
MemoryDefWrapper KillingDefWrapper(
2360-
KillingDef, State.getLocForInst(KillingDef->getMemoryInst()));
2501+
KillingDef, State.getLocForInst(KillingDef->getMemoryInst(),
2502+
EnableInitializesImprovement));
23612503
MadeChange |= State.eliminateDeadDefs(KillingDefWrapper);
23622504
}
23632505

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=function-attrs,dse -enable-dse-initializes-attr-improvement -S | FileCheck %s
3+
4+
; Function Attrs: mustprogress nounwind uwtable
5+
define void @write_only_arg(ptr nocapture noundef writeonly initializes((0, 2)) %ptr) {
6+
store i16 100, ptr %ptr
7+
ret void
8+
}
9+
10+
; Function Attrs: mustprogress nounwind uwtable memory(none argmem: readwrite)
11+
define i16 @write_then_read_arg(ptr nocapture noundef initializes((0, 2)) %ptr) {
12+
store i16 10, ptr %ptr
13+
%l = load i16, ptr %ptr
14+
ret i16 %l
15+
}
16+
17+
; Function Attrs: mustprogress nounwind uwtable
18+
define i16 @write_only_caller() {
19+
; CHECK-LABEL: @write_only_caller(
20+
; CHECK-NEXT: %ptr = alloca i16, align 2
21+
; CHECK-NEXT: call void @write_only_arg(ptr %ptr)
22+
; CHECK-NEXT: %l = load i16, ptr %ptr
23+
; CHECK-NEXT: ret i16 %l
24+
;
25+
%ptr = alloca i16
26+
store i16 0, ptr %ptr
27+
call void @write_only_arg(ptr %ptr)
28+
%l = load i16, ptr %ptr
29+
ret i16 %l
30+
}
31+
32+
; Function Attrs: mustprogress nounwind uwtable
33+
define i16 @write_then_read_caller() {
34+
; CHECK-LABEL: @write_then_read_caller(
35+
; CHECK-NEXT: %ptr = alloca i16, align 2
36+
; CHECK-NEXT: %call = call i16 @write_then_read_arg(ptr %ptr)
37+
; CHECK-NEXT: %l = load i16, ptr %ptr
38+
; CHECK-NEXT: ret i16 %l
39+
;
40+
%ptr = alloca i16
41+
store i16 0, ptr %ptr
42+
%call = call i16 @write_then_read_arg(ptr %ptr)
43+
%l = load i16, ptr %ptr
44+
ret i16 %l
45+
}

0 commit comments

Comments
 (0)