Skip to content

Commit 3581e9b

Browse files
authored
[NFC][IR2Vec] Refactoring for Stateless Embedding Computation (llvm#141811)
Currently, users have to invoke two APIs: `computeEmbeddings()` followed by getters to access the embeddings. This PR refactors the code to reduce this *stateful* access of APIs. Users can now directly invoke getters; Internally, getters would compute the embeddings.
1 parent 259fe01 commit 3581e9b

File tree

3 files changed

+54
-34
lines changed

3 files changed

+54
-34
lines changed

llvm/docs/MLGO.rst

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -490,24 +490,21 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
490490
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
491491
492492
3. **Compute and Access Embeddings**:
493-
Call ``computeEmbeddings()`` on the embedder instance to compute the
494-
embeddings. Then the embeddings can be accessed using different getter
495-
methods. Currently, ``Embedder`` can generate embeddings at three levels:
496-
Instructions, Basic Blocks, and Functions.
493+
Call ``getFunctionVector()`` to get the embedding for the function.
497494

498-
.. code-block:: c++
495+
.. code-block:: c++
499496

500-
Emb->computeEmbeddings();
501497
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
502-
const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
503-
const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
504-
505-
// Example: Iterate over instruction embeddings
506-
for (const auto &Entry : InstVecMap) {
507-
const Instruction *Inst = Entry.getFirst();
508-
const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
509-
// Use Inst and InstEmbedding
510-
}
498+
499+
Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
500+
Basic Blocks, and Functions. Appropriate getters are provided to access the
501+
embeddings at these levels.
502+
503+
.. note::
504+
505+
The validity of ``Embedder`` instance (and the embeddings it generates) is
506+
tied to the function it is associated with remains unchanged. If the function
507+
is modified, the embeddings may become stale and should be recomputed accordingly.
511508

512509
4. **Working with Embeddings:**
513510
Embeddings are represented as ``std::vector<double>``. These

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,17 @@ class Embedder {
8080

8181
// Utility maps - these are used to store the vector representations of
8282
// instructions, basic blocks and functions.
83-
Embedding FuncVector;
84-
BBEmbeddingsMap BBVecMap;
85-
InstEmbeddingsMap InstVecMap;
83+
mutable Embedding FuncVector;
84+
mutable BBEmbeddingsMap BBVecMap;
85+
mutable InstEmbeddingsMap InstVecMap;
8686

8787
Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
8888

89+
/// Helper function to compute embeddings. It generates embeddings for all
90+
/// the instructions and basic blocks in the function F. Logic of computing
91+
/// the embeddings is specific to the kind of embeddings being computed.
92+
virtual void computeEmbeddings() const = 0;
93+
8994
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
9095
/// zero vector.
9196
Embedding lookupVocab(const std::string &Key) const;
@@ -100,25 +105,24 @@ class Embedder {
100105
public:
101106
virtual ~Embedder() = default;
102107

103-
/// Top level function to compute embeddings. It generates embeddings for all
104-
/// the instructions and basic blocks in the function F. Logic of computing
105-
/// the embeddings is specific to the kind of embeddings being computed.
106-
virtual void computeEmbeddings() = 0;
107-
108108
/// Factory method to create an Embedder object.
109109
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
110110
const Function &F,
111111
const Vocab &Vocabulary,
112112
unsigned Dimension);
113113

114-
/// Returns a map containing instructions and the corresponding embeddings.
115-
const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
114+
/// Returns a map containing instructions and the corresponding embeddings for
115+
/// the function F if it has been computed. If not, it computes the embeddings
116+
/// for the function and returns the map.
117+
const InstEmbeddingsMap &getInstVecMap() const;
116118

117-
/// Returns a map containing basic block and the corresponding embeddings.
118-
const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
119+
/// Returns a map containing basic block and the corresponding embeddings for
120+
/// the function F if it has been computed. If not, it computes the embeddings
121+
/// for the function and returns the map.
122+
const BBEmbeddingsMap &getBBVecMap() const;
119123

120-
/// Returns the embedding for the current function.
121-
const Embedding &getFunctionVector() const { return FuncVector; }
124+
/// Computes and returns the embedding for the current function.
125+
const Embedding &getFunctionVector() const;
122126
};
123127

124128
/// Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,21 +131,22 @@ class Embedder {
127131
class SymbolicEmbedder : public Embedder {
128132
private:
129133
/// Utility function to compute the embedding for a given basic block.
130-
Embedding computeBB2Vec(const BasicBlock &BB);
134+
Embedding computeBB2Vec(const BasicBlock &BB) const;
131135

132136
/// Utility function to compute the embedding for a given type.
133137
Embedding getTypeEmbedding(const Type *Ty) const;
134138

135139
/// Utility function to compute the embedding for a given operand.
136140
Embedding getOperandEmbedding(const Value *Op) const;
137141

142+
void computeEmbeddings() const override;
143+
138144
public:
139145
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
140146
unsigned Dimension)
141147
: Embedder(F, Vocabulary, Dimension) {
142148
FuncVector = Embedding(Dimension, 0);
143149
}
144-
void computeEmbeddings() override;
145150
};
146151

147152
} // namespace ir2vec

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,25 @@ Embedding Embedder::lookupVocab(const std::string &Key) const {
103103
return Vec;
104104
}
105105

106+
const InstEmbeddingsMap &Embedder::getInstVecMap() const {
107+
if (InstVecMap.empty())
108+
computeEmbeddings();
109+
return InstVecMap;
110+
}
111+
112+
const BBEmbeddingsMap &Embedder::getBBVecMap() const {
113+
if (BBVecMap.empty())
114+
computeEmbeddings();
115+
return BBVecMap;
116+
}
117+
118+
const Embedding &Embedder::getFunctionVector() const {
119+
// Currently, we always (re)compute the embeddings for the function.
120+
// This is cheaper than caching the vector.
121+
computeEmbeddings();
122+
return FuncVector;
123+
}
124+
106125
#define RETURN_LOOKUP_IF(CONDITION, KEY_STR) \
107126
if (CONDITION) \
108127
return lookupVocab(KEY_STR);
@@ -132,7 +151,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
132151

133152
#undef RETURN_LOOKUP_IF
134153

135-
void SymbolicEmbedder::computeEmbeddings() {
154+
void SymbolicEmbedder::computeEmbeddings() const {
136155
if (F.isDeclaration())
137156
return;
138157
for (const auto &BB : F) {
@@ -142,7 +161,7 @@ void SymbolicEmbedder::computeEmbeddings() {
142161
}
143162
}
144163

145-
Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
164+
Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
146165
Embedding BBVector(Dimension, 0);
147166

148167
for (const auto &I : BB) {
@@ -271,7 +290,6 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
271290
}
272291

273292
std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
274-
Emb->computeEmbeddings();
275293

276294
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
277295
OS << "Function vector: ";

0 commit comments

Comments
 (0)