[NFC][IR2Vec] Refactoring for Stateless Embedding Computation #141811
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo
Author: S. VenkataKeerthy (svkeerthy)
Changes
Currently, users have to invoke two APIs: `computeEmbeddings()` followed by getters to access the embeddings.
Full diff: https://github.com/llvm/llvm-project/pull/141811.diff
3 Files Affected:
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index fa4b02cb11be7..549d8369d648d 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -490,24 +490,21 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
3. **Compute and Access Embeddings**:
- Call ``computeEmbeddings()`` on the embedder instance to compute the
- embeddings. Then the embeddings can be accessed using different getter
- methods. Currently, ``Embedder`` can generate embeddings at three levels:
- Instructions, Basic Blocks, and Functions.
+ Call ``getFunctionVector()`` to get the embedding for the function.
- .. code-block:: c++
+ .. code-block:: c++
- Emb->computeEmbeddings();
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
- const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
- const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
-
- // Example: Iterate over instruction embeddings
- for (const auto &Entry : InstVecMap) {
- const Instruction *Inst = Entry.getFirst();
- const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
- // Use Inst and InstEmbedding
- }
+
+ Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
+ Basic Blocks, and Functions. Appropriate getters are provided to access the
+ embeddings at these levels.
+
+ .. note::
+
+ The validity of ``Embedder`` instance (and the embeddings it generates) is
+ tied to the function it is associated with remains unchanged. If the function
+ is modified, the embeddings may become stale and should be recomputed accordingly.
4. **Working with Embeddings:**
Embeddings are represented as ``std::vector<double>``. These
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..43c95c5e89aed 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -80,12 +80,17 @@ class Embedder {
// Utility maps - these are used to store the vector representations of
// instructions, basic blocks and functions.
- Embedding FuncVector;
- BBEmbeddingsMap BBVecMap;
- InstEmbeddingsMap InstVecMap;
+ mutable Embedding FuncVector;
+ mutable BBEmbeddingsMap BBVecMap;
+ mutable InstEmbeddingsMap InstVecMap;
Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
+ /// Helper function to compute embeddings. It generates embeddings for all
+ /// the instructions and basic blocks in the function F. Logic of computing
+ /// the embeddings is specific to the kind of embeddings being computed.
+ virtual void computeEmbeddings() const = 0;
+
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
/// zero vector.
Embedding lookupVocab(const std::string &Key) const;
@@ -100,25 +105,24 @@ class Embedder {
public:
virtual ~Embedder() = default;
- /// Top level function to compute embeddings. It generates embeddings for all
- /// the instructions and basic blocks in the function F. Logic of computing
- /// the embeddings is specific to the kind of embeddings being computed.
- virtual void computeEmbeddings() = 0;
-
/// Factory method to create an Embedder object.
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
const Function &F,
const Vocab &Vocabulary,
unsigned Dimension);
- /// Returns a map containing instructions and the corresponding embeddings.
- const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
+ /// Returns a map containing instructions and the corresponding embeddings for
+ /// the function F if it has been computed. If not, it computes the embeddings
+ /// for the function and returns the map.
+ const InstEmbeddingsMap &getInstVecMap() const;
- /// Returns a map containing basic block and the corresponding embeddings.
- const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
+ /// Returns a map containing basic block and the corresponding embeddings for
+ /// the function F if it has been computed. If not, it computes the embeddings
+ /// for the function and returns the map.
+ const BBEmbeddingsMap &getBBVecMap() const;
- /// Returns the embedding for the current function.
- const Embedding &getFunctionVector() const { return FuncVector; }
+ /// Computes and returns the embedding for the current function.
+ const Embedding &getFunctionVector() const;
};
/// Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,7 +131,7 @@ class Embedder {
class SymbolicEmbedder : public Embedder {
private:
/// Utility function to compute the embedding for a given basic block.
- Embedding computeBB2Vec(const BasicBlock &BB);
+ Embedding computeBB2Vec(const BasicBlock &BB) const;
/// Utility function to compute the embedding for a given type.
Embedding getTypeEmbedding(const Type *Ty) const;
@@ -135,13 +139,14 @@ class SymbolicEmbedder : public Embedder {
/// Utility function to compute the embedding for a given operand.
Embedding getOperandEmbedding(const Value *Op) const;
+ void computeEmbeddings() const override;
+
public:
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
unsigned Dimension)
: Embedder(F, Vocabulary, Dimension) {
FuncVector = Embedding(Dimension, 0);
}
- void computeEmbeddings() override;
};
} // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..5f3114dcdeeaa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -103,6 +103,25 @@ Embedding Embedder::lookupVocab(const std::string &Key) const {
return Vec;
}
+const InstEmbeddingsMap &Embedder::getInstVecMap() const {
+ if (InstVecMap.empty())
+ computeEmbeddings();
+ return InstVecMap;
+}
+
+const BBEmbeddingsMap &Embedder::getBBVecMap() const {
+ if (BBVecMap.empty())
+ computeEmbeddings();
+ return BBVecMap;
+}
+
+const Embedding &Embedder::getFunctionVector() const {
+ // Currently, we always (re)compute the embeddings for the function.
+ // This is cheaper than caching the vector.
+ computeEmbeddings();
+ return FuncVector;
+}
+
#define RETURN_LOOKUP_IF(CONDITION, KEY_STR) \
if (CONDITION) \
return lookupVocab(KEY_STR);
@@ -132,7 +151,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
#undef RETURN_LOOKUP_IF
-void SymbolicEmbedder::computeEmbeddings() {
+void SymbolicEmbedder::computeEmbeddings() const {
if (F.isDeclaration())
return;
for (const auto &BB : F) {
@@ -142,7 +161,7 @@ void SymbolicEmbedder::computeEmbeddings() {
}
}
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
for (const auto &I : BB) {
@@ -271,7 +290,6 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
}
std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
- Emb->computeEmbeddings();
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
It looks like the `mutable` objects are essentially caches and are marked that way given they are intended to be an internal implementation detail?
Yes, that's the idea. The internal data structures that hold embeddings kind of act like caches.
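For readers less familiar with the pattern discussed in this review thread, here is a minimal standalone sketch of a `const` getter that lazily fills a `mutable` cache. It is not LLVM code: `ToyEmbedder`, `getMap()`, `compute()`, and `demoLazyGetter()` are hypothetical names chosen for illustration, and the "embedding" is just the string length.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an embedder; this is NOT ir2vec::Embedder.
class ToyEmbedder {
  std::vector<std::string> Items;               // input the embedder is bound to
  mutable std::map<std::string, double> Cache;  // internal cache, filled lazily
  mutable unsigned ComputeCalls = 0;            // counts how often compute() ran

  // Private helper. It can be const because it only writes mutable members.
  void compute() const {
    ++ComputeCalls;
    for (const std::string &S : Items)
      Cache[S] = static_cast<double>(S.size()); // toy "embedding": name length
  }

public:
  explicit ToyEmbedder(std::vector<std::string> In) : Items(std::move(In)) {}

  // Getter computes on first use, then returns the cached map.
  const std::map<std::string, double> &getMap() const {
    if (Cache.empty())
      compute();
    return Cache;
  }

  unsigned computeCalls() const { return ComputeCalls; }
};

// Calls the getter twice on a const object and reports how often compute() ran.
unsigned demoLazyGetter() {
  std::vector<std::string> Names = {"add", "ret"};
  const ToyEmbedder E(Names);
  const auto &First = E.getMap();
  const auto &Second = E.getMap();
  assert(&First == &Second);       // same cached object both times
  assert(First.at("add") == 3.0);  // "add" has three characters
  return E.computeCalls();
}
```

In this sketch the caller never sees a separate compute step, which is the same shape of API the PR moves toward; whether the real implementation behaves identically in every case is not something this toy establishes.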
…41811) Currently, users have to invoke two APIs: `computeEmbeddings()` followed by getters to access the embeddings. This PR refactors the code to reduce this *stateful* access of APIs. Users can now directly invoke getters; Internally, getters would compute the embeddings.
Currently, users have to invoke two APIs: `computeEmbeddings()` followed by getters to access the embeddings. This PR refactors the code to reduce this stateful access of APIs. Users can now directly invoke getters; internally, getters would compute the embeddings. (Tracking Issue: #141817)
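The documentation note added in the diff warns that embeddings may become stale if the function is modified. The following standalone sketch illustrates why a compute-if-empty getter has that property. It is a toy model, not LLVM code: `ToyFunction`, `CachingEmbedder`, `getPerItem()`, and `demoStaleness()` are hypothetical names, and constructing a fresh embedder is only one assumed way to "recompute"; the PR does not prescribe a mechanism.

```cpp
#include <numeric>
#include <vector>

// Hypothetical stand-in for llvm::Function: just a list of integers.
using ToyFunction = std::vector<int>;

// Hypothetical embedder bound to one ToyFunction by reference.
class CachingEmbedder {
  const ToyFunction &F;
  mutable std::vector<int> PerItem; // cache, empty until first access

  void compute() const { PerItem.assign(F.begin(), F.end()); }

public:
  explicit CachingEmbedder(const ToyFunction &Fn) : F(Fn) {}

  // Computes only when the cache is empty, so later edits to F go unnoticed.
  const std::vector<int> &getPerItem() const {
    if (PerItem.empty())
      compute();
    return PerItem;
  }
};

static int sumOf(const std::vector<int> &V) {
  return std::accumulate(V.begin(), V.end(), 0);
}

// Returns true if the old embedder kept its stale result after the input
// changed, while a freshly constructed embedder saw the new contents.
bool demoStaleness() {
  ToyFunction Fn = {1, 2, 3};
  CachingEmbedder Old(Fn);
  int Before = sumOf(Old.getPerItem());  // cache filled from {1, 2, 3}

  Fn.push_back(4);                       // "modify the function"
  int Stale = sumOf(Old.getPerItem());   // cache is non-empty, so not recomputed

  CachingEmbedder Fresh(Fn);             // assumed remedy: build a new embedder
  int After = sumOf(Fresh.getPerItem()); // computed from {1, 2, 3, 4}

  return Before == 6 && Stale == 6 && After == 10;
}
```

The point of the toy is only that a non-empty cache masks later changes to the underlying input; callers who mutate the function would need to obtain fresh embeddings by whatever means the real API supports.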