Skip to content

Commit fa2f2ae

Browse files
committed
[BOLT] Gadget scanner: use more appropriate types (NFC)
* use more flexible `const ArrayRef<T>` and `StringRef` types instead of `const std::vector<T> &` and `const std::string &`, correspondingly, for function arguments * return plain `const SrcState &` instead of `ErrorOr<const SrcState &>` from `SrcSafetyAnalysis::getStateBefore`, as absent state is not handled gracefully by any caller
1 parent f5455c6 commit fa2f2ae

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

bolt/include/bolt/Passes/PAuthGadgetScanner.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "bolt/Core/BinaryContext.h"
1313
#include "bolt/Core/BinaryFunction.h"
1414
#include "bolt/Passes/BinaryPasses.h"
15-
#include "llvm/ADT/SmallSet.h"
1615
#include "llvm/Support/raw_ostream.h"
1716
#include <memory>
1817

@@ -199,9 +198,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);
199198

200199
namespace PAuthGadgetScanner {
201200

202-
class SrcSafetyAnalysis;
203-
struct SrcState;
204-
205201
/// Description of a gadget kind that can be detected. Intended to be
206202
/// statically allocated to be attached to reports by reference.
207203
class GadgetKind {
@@ -210,7 +206,7 @@ class GadgetKind {
210206
public:
211207
GadgetKind(const char *Description) : Description(Description) {}
212208

213-
const StringRef getDescription() const { return Description; }
209+
StringRef getDescription() const { return Description; }
214210
};
215211

216212
/// Base report located at some instruction, without any additional information.
@@ -261,7 +257,7 @@ struct GadgetReport : public Report {
261257
/// Report with a free-form message attached.
262258
struct GenericReport : public Report {
263259
std::string Text;
264-
GenericReport(MCInstReference Location, const std::string &Text)
260+
GenericReport(MCInstReference Location, StringRef Text)
265261
: Report(Location), Text(Text) {}
266262
virtual void generateReport(raw_ostream &OS,
267263
const BinaryContext &BC) const override;

bolt/lib/Passes/PAuthGadgetScanner.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ class TrackedRegisters {
9191
const std::vector<MCPhysReg> Registers;
9292
std::vector<uint16_t> RegToIndexMapping;
9393

94-
static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
94+
static size_t getMappingSize(const ArrayRef<MCPhysReg> RegsToTrack) {
9595
if (RegsToTrack.empty())
9696
return 0;
9797
return 1 + *llvm::max_element(RegsToTrack);
9898
}
9999

100100
public:
101-
TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
101+
TrackedRegisters(const ArrayRef<MCPhysReg> RegsToTrack)
102102
: Registers(RegsToTrack),
103103
RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
104104
for (unsigned I = 0; I < RegsToTrack.size(); ++I)
@@ -234,7 +234,7 @@ struct SrcState {
234234

235235
static void printLastInsts(
236236
raw_ostream &OS,
237-
const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) {
237+
const ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) {
238238
OS << "Insts: ";
239239
for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) {
240240
auto &Set = LastInstWritingReg[I];
@@ -295,19 +295,18 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
295295
class SrcSafetyAnalysis {
296296
public:
297297
SrcSafetyAnalysis(BinaryFunction &BF,
298-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
298+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
299299
: BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
300300
RegsToTrackInstsFor(RegsToTrackInstsFor) {}
301301

302302
virtual ~SrcSafetyAnalysis() {}
303303

304304
static std::shared_ptr<SrcSafetyAnalysis>
305305
create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
306-
const std::vector<MCPhysReg> &RegsToTrackInstsFor);
306+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor);
307307

308308
virtual void run() = 0;
309-
virtual ErrorOr<const SrcState &>
310-
getStateBefore(const MCInst &Inst) const = 0;
309+
virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0;
311310

312311
protected:
313312
BinaryContext &BC;
@@ -347,7 +346,7 @@ class SrcSafetyAnalysis {
347346
}
348347

349348
BitVector getClobberedRegs(const MCInst &Point) const {
350-
BitVector Clobbered(NumRegs, false);
349+
BitVector Clobbered(NumRegs);
351350
// Assume a call can clobber all registers, including callee-saved
352351
// registers. There's a good chance that callee-saved registers will be
353352
// saved on the stack at some point during execution of the callee.
@@ -408,8 +407,7 @@ class SrcSafetyAnalysis {
408407

409408
// FirstCheckerInst should belong to the same basic block, meaning
410409
// it was deterministically processed a few steps before this instruction.
411-
const SrcState &StateBeforeChecker =
412-
getStateBefore(*FirstCheckerInst).get();
410+
const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst);
413411
if (StateBeforeChecker.SafeToDerefRegs[CheckedReg])
414412
Regs.push_back(CheckedReg);
415413
}
@@ -522,10 +520,7 @@ class SrcSafetyAnalysis {
522520
const ArrayRef<MCPhysReg> UsedDirtyRegs) const {
523521
if (RegsToTrackInstsFor.empty())
524522
return {};
525-
auto MaybeState = getStateBefore(Inst);
526-
if (!MaybeState)
527-
llvm_unreachable("Expected state to be present");
528-
const SrcState &S = *MaybeState;
523+
const SrcState &S = getStateBefore(Inst);
529524
// Due to aliasing registers, multiple registers may have been tracked.
530525
std::set<const MCInst *> LastWritingInsts;
531526
for (MCPhysReg TrackedReg : UsedDirtyRegs) {
@@ -536,7 +531,7 @@ class SrcSafetyAnalysis {
536531
for (const MCInst *Inst : LastWritingInsts) {
537532
MCInstReference Ref = MCInstReference::get(Inst, BF);
538533
assert(Ref && "Expected Inst to be found");
539-
Result.push_back(MCInstReference(Ref));
534+
Result.push_back(Ref);
540535
}
541536
return Result;
542537
}
@@ -556,11 +551,11 @@ class DataflowSrcSafetyAnalysis
556551
public:
557552
DataflowSrcSafetyAnalysis(BinaryFunction &BF,
558553
MCPlusBuilder::AllocatorIdTy AllocId,
559-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
554+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
560555
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}
561556

562-
ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
563-
return DFParent::getStateBefore(Inst);
557+
const SrcState &getStateBefore(const MCInst &Inst) const override {
558+
return DFParent::getStateBefore(Inst).get();
564559
}
565560

566561
void run() override {
@@ -669,7 +664,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
669664
public:
670665
CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF,
671666
MCPlusBuilder::AllocatorIdTy AllocId,
672-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
667+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
673668
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
674669
StateAnnotationIndex =
675670
BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis");
@@ -703,7 +698,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
703698
}
704699
}
705700

706-
ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
701+
const SrcState &getStateBefore(const MCInst &Inst) const override {
707702
return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex);
708703
}
709704

@@ -713,7 +708,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
713708
std::shared_ptr<SrcSafetyAnalysis>
714709
SrcSafetyAnalysis::create(BinaryFunction &BF,
715710
MCPlusBuilder::AllocatorIdTy AllocId,
716-
const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
711+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor) {
717712
if (BF.hasCFG())
718713
return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId,
719714
RegsToTrackInstsFor);
@@ -820,7 +815,7 @@ Analysis::findGadgets(BinaryFunction &BF,
820815

821816
BinaryContext &BC = BF.getBinaryContext();
822817
iterateOverInstrs(BF, [&](MCInstReference Inst) {
823-
const SrcState &S = *Analysis->getStateBefore(Inst);
818+
const SrcState &S = Analysis->getStateBefore(Inst);
824819

825820
// If non-empty state was never propagated from the entry basic block
826821
// to Inst, assume it to be unreachable and report a warning.

0 commit comments

Comments
 (0)