Skip to content

Commit 4c0a0f7

Browse files
authored
[SandboxVectorizer][NFCI] Fix use of possibly-uninitialized scalar. (#122201)
The `EraseCallbackID` field is not always initialized in the ctor for SeedCollector; if not, it will be used uninitialized by its dtor. This could potentially lead to the erasure of a random callback, leading to a bug. Fixed by making `CallbackID` an opaque type, which is always default-initialized to an invalid ID.
1 parent a8e1135 commit 4c0a0f7

File tree

2 files changed

+50
-10
lines changed

2 files changed

+50
-10
lines changed

llvm/include/llvm/SandboxIR/Context.h

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
#include <cstdint>
2020

21-
namespace llvm::sandboxir {
21+
namespace llvm {
22+
namespace sandboxir {
2223

2324
class Argument;
2425
class BBIterator;
@@ -37,10 +38,28 @@ class Context {
3738
using MoveInstrCallback =
3839
std::function<void(Instruction *, const BBIterator &)>;
3940

40-
/// An ID for a registered callback. Used for deregistration. Using a 64-bit
41-
/// integer so we don't have to worry about the unlikely case of overflowing
42-
/// a 32-bit counter.
43-
using CallbackID = uint64_t;
41+
/// An ID for a registered callback. Used for deregistration. A dedicated type
42+
/// is employed so as to keep IDs opaque to the end user; only Context should
43+
/// deal with its underlying representation.
44+
class CallbackID {
45+
public:
46+
// Uses a 64-bit integer so we don't have to worry about the unlikely case
47+
// of overflowing a 32-bit counter.
48+
using ValTy = uint64_t;
49+
static constexpr const ValTy InvalidVal = 0;
50+
51+
private:
52+
// Default initialization results in an invalid ID.
53+
ValTy Val = InvalidVal;
54+
explicit CallbackID(ValTy Val) : Val{Val} {
55+
assert(Val != InvalidVal && "newly-created ID is invalid!");
56+
}
57+
58+
public:
59+
CallbackID() = default;
60+
friend class Context;
61+
friend struct DenseMapInfo<CallbackID>;
62+
};
4463

4564
protected:
4665
LLVMContext &LLVMCtx;
@@ -83,7 +102,7 @@ class Context {
83102
/// A counter used for assigning callback IDs during registration. The same
84103
/// counter is used for all kinds of callbacks so we can detect mismatched
85104
/// registration/deregistration.
86-
CallbackID NextCallbackID = 0;
105+
CallbackID::ValTy NextCallbackID = 1;
87106

88107
/// Remove \p V from the maps and returns the unique_ptr.
89108
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
@@ -263,6 +282,27 @@ class Context {
263282
// TODO: Add callbacks for instructions inserted/removed if needed.
264283
};
265284

266-
} // namespace llvm::sandboxir
285+
} // namespace sandboxir
286+
287+
// DenseMap info for CallbackIDs
288+
template <> struct DenseMapInfo<sandboxir::Context::CallbackID> {
289+
using CallbackID = sandboxir::Context::CallbackID;
290+
using ReprInfo = DenseMapInfo<CallbackID::ValTy>;
291+
292+
static CallbackID getEmptyKey() {
293+
return CallbackID{ReprInfo::getEmptyKey()};
294+
}
295+
static CallbackID getTombstoneKey() {
296+
return CallbackID{ReprInfo::getTombstoneKey()};
297+
}
298+
static unsigned getHashValue(const CallbackID &ID) {
299+
return ReprInfo::getHashValue(ID.Val);
300+
}
301+
static bool isEqual(const CallbackID &LHS, const CallbackID &RHS) {
302+
return ReprInfo::isEqual(LHS.Val, RHS.Val);
303+
}
304+
};
305+
306+
} // namespace llvm
267307

268308
#endif // LLVM_SANDBOXIR_CONTEXT_H

llvm/lib/SandboxIR/Context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
686686
Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
687687
assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
688688
"EraseInstrCallbacks size limit exceeded");
689-
CallbackID ID = NextCallbackID++;
689+
CallbackID ID{NextCallbackID++};
690690
EraseInstrCallbacks[ID] = CB;
691691
return ID;
692692
}
@@ -700,7 +700,7 @@ Context::CallbackID
700700
Context::registerCreateInstrCallback(CreateInstrCallback CB) {
701701
assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
702702
"CreateInstrCallbacks size limit exceeded");
703-
CallbackID ID = NextCallbackID++;
703+
CallbackID ID{NextCallbackID++};
704704
CreateInstrCallbacks[ID] = CB;
705705
return ID;
706706
}
@@ -713,7 +713,7 @@ void Context::unregisterCreateInstrCallback(CallbackID ID) {
713713
Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
714714
assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
715715
"MoveInstrCallbacks size limit exceeded");
716-
CallbackID ID = NextCallbackID++;
716+
CallbackID ID{NextCallbackID++};
717717
MoveInstrCallbacks[ID] = CB;
718718
return ID;
719719
}

0 commit comments

Comments
 (0)