Skip to content

Commit 5c7dca7

Browse files
committed
[NFC][GS/StackClash] use LatticeT to implement lattice semantics for Reg2MaxOffset
, rather than hand-rolling it poorly using std::optional.
1 parent 95119f4 commit 5c7dca7

File tree

1 file changed

+73
-41
lines changed

1 file changed

+73
-41
lines changed

bolt/lib/Passes/StackClashAnalysis.cpp

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ bool addToMaxMap(SmallDenseMap<MCPhysReg, uint64_t, 1> &M, MCPhysReg R,
7272

7373
template <typename T, auto MergeValLambda> class LatticeT {
7474
private:
75-
enum LValType { _Bottom, _Top, Value } LValType;
75+
enum LValType { _Bottom, Value, _Top } LValType;
7676
T V;
7777
LatticeT(enum LValType ValType, T Val) : LValType(ValType), V(Val) {}
7878
static LatticeT _TopV; //(_Top, T());
@@ -111,10 +111,31 @@ template <typename T, auto MergeValLambda> class LatticeT {
111111
return LValType == RHS.LValType && V == RHS.V;
112112
}
113113
bool operator!=(const LatticeT &RHS) const { return !(*this == RHS); }
114+
bool operator<(const LatticeT &RHS) const {
115+
if (LValType < RHS.LValType)
116+
return true;
117+
if (LValType > RHS.LValType)
118+
return false;
119+
assert(LValType == RHS.LValType);
120+
if (LValType == Value)
121+
return V < RHS.V;
122+
else
123+
return false;
124+
}
125+
126+
bool hasVal() const { return *this != Bottom() && *this != Top(); }
114127
const T &getVal() const {
115-
assert(*this != Bottom() && *this != Top());
128+
assert(hasVal());
116129
return V;
117130
}
131+
T &getVal() {
132+
assert(hasVal());
133+
return V;
134+
}
135+
T *operator->() { return &getVal(); }
136+
const T *operator->() const { return &getVal(); }
137+
T &operator*() { return getVal(); }
138+
const T &operator*() const { return getVal(); }
118139
LatticeT &doOnVal(std::function<const T &(T &, const T &)> f, const T &V2) {
119140
assert(*this != Bottom());
120141
if (*this == Top())
@@ -149,6 +170,45 @@ MaxOffsetT &operator+=(MaxOffsetT &O1, const int64_t O2) {
149170
return O1.doOnVal(AddOffset, O2);
150171
}
151172

173+
using Reg2MaxOffsetValT = SmallDenseMap<MCPhysReg, MaxOffsetT, 2>;
174+
bool Reg2MaxOffsetMergeVal(Reg2MaxOffsetValT &v1, const Reg2MaxOffsetValT &v2) {
175+
SmallVector<MCPhysReg, 1> RegMaxValuesToRemove;
176+
for (auto Reg2MaxValue : v1) {
177+
const MCPhysReg R(Reg2MaxValue.first);
178+
if (auto v2Reg2MaxValue = v2.find(R); v2Reg2MaxValue == v2.end())
179+
RegMaxValuesToRemove.push_back(R);
180+
else
181+
Reg2MaxValue.second =
182+
std::max(Reg2MaxValue.second, v2Reg2MaxValue->second);
183+
// FIXME: this should be a "confluence" - similar
184+
// to MaxOffsetT? To avoid near infinite loops?
185+
}
186+
for (MCPhysReg R : RegMaxValuesToRemove)
187+
v1.erase(R);
188+
return true;
189+
}
190+
191+
void print_reg(raw_ostream &OS, MCPhysReg Reg, const BinaryContext *BC) {
192+
if (!BC)
193+
OS << "R" << Reg;
194+
else {
195+
RegStatePrinter RegStatePrinter(*BC);
196+
BitVector BV(BC->MRI->getNumRegs(), false);
197+
BV.set(Reg);
198+
RegStatePrinter.print(OS, BV);
199+
}
200+
}
201+
202+
raw_ostream &operator<<(raw_ostream &OS, const Reg2MaxOffsetValT &M) {
203+
for (auto Reg2Value : M) {
204+
print_reg(OS, Reg2Value.first, nullptr);
205+
OS << ":" << Reg2Value.second << ",";
206+
}
207+
return OS;
208+
}
209+
210+
using Reg2MaxOffsetT = LatticeT<Reg2MaxOffsetValT, Reg2MaxOffsetMergeVal>;
211+
152212
struct State {
153213
// Store the maximum possible offset to which the stack extends
154214
// beyond the furthest probe seen.
@@ -172,7 +232,7 @@ struct State {
172232
/// This is only tracked in Basic Blocks that are known to be reachable
173233
/// from an entry block. For blocks not (yet) known to be reachable from
174234
/// an entry block, the optional does not contain a value.
175-
std::optional<SmallDenseMap<MCPhysReg, MaxOffsetT, 2>> Reg2MaxOffset;
235+
Reg2MaxOffsetT Reg2MaxOffset;
176236
// FIXME: It seems that conceptually it does not make sense to
177237
// track wheterh the SP value is currently at a fixed offset from
178238
// the value it was at function entry.
@@ -223,24 +283,7 @@ struct State {
223283
else if (*SPFixedOffsetFromOrig != *StateIn.SPFixedOffsetFromOrig)
224284
SPFixedOffsetFromOrig.reset();
225285

226-
if (StateIn.Reg2MaxOffset && Reg2MaxOffset) {
227-
SmallVector<MCPhysReg, 2> RToRemove;
228-
for (auto R2MaxOff : *Reg2MaxOffset) {
229-
const MCPhysReg R = R2MaxOff.first;
230-
if (auto SIn_R2MaxOff = StateIn.Reg2MaxOffset->find(R);
231-
SIn_R2MaxOff == StateIn.Reg2MaxOffset->end())
232-
RToRemove.push_back(R);
233-
else {
234-
MaxOffsetT MaxOff1 = R2MaxOff.second;
235-
MaxOffsetT MaxOff2 = SIn_R2MaxOff->second;
236-
MaxOff1 &= MaxOff2;
237-
}
238-
for (auto R : RToRemove)
239-
Reg2MaxOffset->erase(R);
240-
}
241-
} else if (StateIn.Reg2MaxOffset && !Reg2MaxOffset) {
242-
Reg2MaxOffset = StateIn.Reg2MaxOffset;
243-
}
286+
Reg2MaxOffset &= StateIn.Reg2MaxOffset;
244287

245288
for (auto I : StateIn.LastStackGrowingInsts)
246289
LastStackGrowingInsts.insert(I);
@@ -256,17 +299,6 @@ struct State {
256299
bool operator!=(const State &RHS) const { return !((*this) == RHS); }
257300
};
258301

259-
void print_reg(raw_ostream &OS, MCPhysReg Reg, const BinaryContext *BC) {
260-
if (!BC)
261-
OS << "R" << Reg;
262-
else {
263-
RegStatePrinter RegStatePrinter(*BC);
264-
BitVector BV(BC->MRI->getNumRegs(), false);
265-
BV.set(Reg);
266-
RegStatePrinter.print(OS, BV);
267-
}
268-
}
269-
270302
template <class T, unsigned N>
271303
void PrintRegMap(raw_ostream &OS, const SmallDenseMap<MCPhysReg, T, N> &M,
272304
const BinaryContext *BC = nullptr) {
@@ -290,12 +322,12 @@ raw_ostream &print_state(raw_ostream &OS, const State &S,
290322
OS << "),";
291323
OS << "SPFixedOffsetFromOrig:" << S.SPFixedOffsetFromOrig << ",";
292324
OS << "Reg2MaxOffset:";
293-
if (S.Reg2MaxOffset) {
325+
if (S.Reg2MaxOffset.hasVal()) {
294326
OS << "(";
295-
PrintRegMap(OS, *S.Reg2MaxOffset, BC);
327+
PrintRegMap(OS, S.Reg2MaxOffset.getVal(), BC);
296328
OS << ")";
297329
} else
298-
OS << "None";
330+
OS << S.Reg2MaxOffset;
299331
OS << ",";
300332
OS << "LastStackGrowingInsts(" << S.LastStackGrowingInsts.size() << ")> ";
301333
return OS;
@@ -363,8 +395,8 @@ bool checkNonConstSPOffsetChange(const BinaryContext &BC, BinaryFunction &BF,
363395
// assert(!OC.IsPreIndexOffsetChange || IsStackAccess);
364396
if (Next)
365397
assert(*Next->MaxOffsetSinceLastProbe >= 0);
366-
} else if (Cur.Reg2MaxOffset && Cur.Reg2MaxOffset->contains(OC.FromReg) &&
367-
OC.OffsetChange) {
398+
} else if (Cur.Reg2MaxOffset.hasVal() &&
399+
Cur.Reg2MaxOffset->contains(OC.FromReg) && OC.OffsetChange) {
368400
IsNonConstantSPOffsetChange = false;
369401
const MaxOffsetT MaxOffset =
370402
Cur.Reg2MaxOffset->find(OC.FromReg)->second;
@@ -389,7 +421,7 @@ bool checkNonConstSPOffsetChange(const BinaryContext &BC, BinaryFunction &BF,
389421
uint64_t Mask = 0;
390422
if (MCPhysReg FromReg, ToReg;
391423
BC.MIB->isMaskLowerBitsInReg(Point, FromReg, ToReg, Mask) &&
392-
Cur.Reg2MaxOffset && Cur.Reg2MaxOffset->contains(FromReg)) {
424+
Cur.Reg2MaxOffset.hasVal() && Cur.Reg2MaxOffset->contains(FromReg)) {
393425
// handle SP-aligning patterns like
394426
// sub x9, sp, #0x1d0
395427
// and sp, x9, #0xffffffffffffff80
@@ -433,7 +465,7 @@ class StackClashDFAnalysis
433465
State getStartingStateAtBB(const BinaryBasicBlock &BB) {
434466
State Next;
435467
if (BB.isEntryPoint())
436-
Next.Reg2MaxOffset = SmallDenseMap<MCPhysReg, MaxOffsetT, 2>();
468+
Next.Reg2MaxOffset = Reg2MaxOffsetValT();
437469
return Next;
438470
}
439471

@@ -541,7 +573,7 @@ class StackClashDFAnalysis
541573
MCPhysReg FixedOffsetRegJustSet = BC.MIB->getNoRegister();
542574
if (auto OC = BC.MIB->getOffsetChange(Point, Cur.RegConstValues,
543575
Cur.RegMaxValues))
544-
if (Next.Reg2MaxOffset && OC.OffsetChange) {
576+
if (Next.Reg2MaxOffset.hasVal() && OC.OffsetChange) {
545577
int64_t Offset = *OC.OffsetChange;
546578
if (OC.FromReg == SP) {
547579
MaxOffsetT &MaxOffset = (*Next.Reg2MaxOffset)[OC.ToReg] =
@@ -555,7 +587,7 @@ class StackClashDFAnalysis
555587
FixedOffsetRegJustSet = OC.ToReg;
556588
}
557589
}
558-
if (Next.Reg2MaxOffset)
590+
if (Next.Reg2MaxOffset.hasVal())
559591
for (const MCOperand &Operand : BC.MIB->defOperands(Point)) {
560592
if (Operand.getReg() != FixedOffsetRegJustSet) {
561593
Next.Reg2MaxOffset->erase(Operand.getReg());

0 commit comments

Comments
 (0)