
[NFC][IR2Vec] Refactoring for Stateless Embedding Computation #141811


Merged: svkeerthy merged 1 commit into main from users/svkeerthy/05-28-reducing_state on May 28, 2025


@svkeerthy (Contributor) commented May 28, 2025

Currently, users have to invoke two APIs: computeEmbeddings(), followed by a getter to access the embeddings. This PR refactors the code to reduce this stateful use of the API. Users can now invoke the getters directly; internally, the getters compute the embeddings.

(Tracking Issue: #141817)
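
To make the described change concrete, here is a minimal, self-contained sketch of a getter that computes on demand. It is not the LLVM code: `ToyEmbedder`, its string "instructions", and the length-based vectors are hypothetical stand-ins chosen only so the example compiles without LLVM headers.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ir2vec::Embedding; not the LLVM type.
using Embedding = std::vector<double>;

// Hypothetical toy embedder. Strings play the role of instructions.
class ToyEmbedder {
  std::vector<std::string> Insts;
  // Filled on first access, so it must be writable from const getters.
  mutable std::map<std::string, Embedding> InstVecMap;

  // Private helper: callers can no longer forget (or repeat) this step,
  // because only the getter invokes it.
  void computeEmbeddings() const {
    for (const std::string &I : Insts)
      InstVecMap[I] = Embedding(2, static_cast<double>(I.size()));
  }

public:
  explicit ToyEmbedder(std::vector<std::string> Is) : Insts(std::move(Is)) {}

  // Single entry point: computes the map if it is still empty, then returns it.
  const std::map<std::string, Embedding> &getInstVecMap() const {
    if (InstVecMap.empty())
      computeEmbeddings();
    return InstVecMap;
  }
};
```

With this shape, a call site goes from "compute, then get" to a single `Emb.getInstVecMap()` call, which is the usage change the description refers to.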


@svkeerthy svkeerthy changed the title Reducing state [NFC][IR2Vec] Reducing state May 28, 2025
@svkeerthy svkeerthy changed the title [NFC][IR2Vec] Reducing state [NFC][IR2Vec] Refactoring for Stateless Embedding Computation May 28, 2025
@svkeerthy svkeerthy force-pushed the users/svkeerthy/05-28-reducing_state branch 3 times, most recently from 4023ab7 to 5eab1c1 Compare May 28, 2025 18:10
@svkeerthy svkeerthy marked this pull request as ready for review May 28, 2025 18:10
@llvmbot llvmbot added the mlgo and llvm:analysis (includes value tracking, cost tables and constant folding) labels May 28, 2025
@llvmbot (Member) commented May 28, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)


Full diff: https://github.com/llvm/llvm-project/pull/141811.diff

3 Files Affected:

  • (modified) llvm/docs/MLGO.rst (+12-15)
  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+21-16)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+21-3)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index fa4b02cb11be7..549d8369d648d 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -490,24 +490,21 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
       std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
 
 3. **Compute and Access Embeddings**:
-   Call ``computeEmbeddings()`` on the embedder instance to compute the 
-   embeddings. Then the embeddings can be accessed using different getter 
-   methods. Currently, ``Embedder`` can generate embeddings at three levels:
-   Instructions, Basic Blocks, and Functions.
+   Call ``getFunctionVector()`` to get the embedding for the function. 
 
-   .. code-block:: c++
+  .. code-block:: c++
 
-      Emb->computeEmbeddings();
       const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
-      const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
-      const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
-
-      // Example: Iterate over instruction embeddings
-      for (const auto &Entry : InstVecMap) {
-        const Instruction *Inst = Entry.getFirst();
-        const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
-        // Use Inst and InstEmbedding
-      }
+
+   Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
+   Basic Blocks, and Functions. Appropriate getters are provided to access the
+   embeddings at these levels.
+
+  .. note::
+
+    The validity of ``Embedder`` instance (and the embeddings it generates) is
+    tied to the function it is associated with remains unchanged. If the function
+    is modified, the embeddings may become stale and should be recomputed accordingly.
 
 4. **Working with Embeddings:**
    Embeddings are represented as ``std::vector<double>``. These
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3063040093402..43c95c5e89aed 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -80,12 +80,17 @@ class Embedder {
 
   // Utility maps - these are used to store the vector representations of
   // instructions, basic blocks and functions.
-  Embedding FuncVector;
-  BBEmbeddingsMap BBVecMap;
-  InstEmbeddingsMap InstVecMap;
+  mutable Embedding FuncVector;
+  mutable BBEmbeddingsMap BBVecMap;
+  mutable InstEmbeddingsMap InstVecMap;
 
   Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
 
+  /// Helper function to compute embeddings. It generates embeddings for all
+  /// the instructions and basic blocks in the function F. Logic of computing
+  /// the embeddings is specific to the kind of embeddings being computed.
+  virtual void computeEmbeddings() const = 0;
+
   /// Lookup vocabulary for a given Key. If the key is not found, it returns a
   /// zero vector.
   Embedding lookupVocab(const std::string &Key) const;
@@ -100,25 +105,24 @@ class Embedder {
 public:
   virtual ~Embedder() = default;
 
-  /// Top level function to compute embeddings. It generates embeddings for all
-  /// the instructions and basic blocks in the function F. Logic of computing
-  /// the embeddings is specific to the kind of embeddings being computed.
-  virtual void computeEmbeddings() = 0;
-
   /// Factory method to create an Embedder object.
   static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
                                                     const Function &F,
                                                     const Vocab &Vocabulary,
                                                     unsigned Dimension);
 
-  /// Returns a map containing instructions and the corresponding embeddings.
-  const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
+  /// Returns a map containing instructions and the corresponding embeddings for
+  /// the function F if it has been computed. If not, it computes the embeddings
+  /// for the function and returns the map.
+  const InstEmbeddingsMap &getInstVecMap() const;
 
-  /// Returns a map containing basic block and the corresponding embeddings.
-  const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
+  /// Returns a map containing basic block and the corresponding embeddings for
+  /// the function F if it has been computed. If not, it computes the embeddings
+  /// for the function and returns the map.
+  const BBEmbeddingsMap &getBBVecMap() const;
 
-  /// Returns the embedding for the current function.
-  const Embedding &getFunctionVector() const { return FuncVector; }
+  /// Computes and returns the embedding for the current function.
+  const Embedding &getFunctionVector() const;
 };
 
 /// Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,7 +131,7 @@ class Embedder {
 class SymbolicEmbedder : public Embedder {
 private:
   /// Utility function to compute the embedding for a given basic block.
-  Embedding computeBB2Vec(const BasicBlock &BB);
+  Embedding computeBB2Vec(const BasicBlock &BB) const;
 
   /// Utility function to compute the embedding for a given type.
   Embedding getTypeEmbedding(const Type *Ty) const;
@@ -135,13 +139,14 @@ class SymbolicEmbedder : public Embedder {
   /// Utility function to compute the embedding for a given operand.
   Embedding getOperandEmbedding(const Value *Op) const;
 
+  void computeEmbeddings() const override;
+
 public:
   SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
                    unsigned Dimension)
       : Embedder(F, Vocabulary, Dimension) {
     FuncVector = Embedding(Dimension, 0);
   }
-  void computeEmbeddings() override;
 };
 
 } // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index cc419c84e9881..5f3114dcdeeaa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -103,6 +103,25 @@ Embedding Embedder::lookupVocab(const std::string &Key) const {
   return Vec;
 }
 
+const InstEmbeddingsMap &Embedder::getInstVecMap() const {
+  if (InstVecMap.empty())
+    computeEmbeddings();
+  return InstVecMap;
+}
+
+const BBEmbeddingsMap &Embedder::getBBVecMap() const {
+  if (BBVecMap.empty())
+    computeEmbeddings();
+  return BBVecMap;
+}
+
+const Embedding &Embedder::getFunctionVector() const {
+  // Currently, we always (re)compute the embeddings for the function.
+  // This is cheaper than caching the vector.
+  computeEmbeddings();
+  return FuncVector;
+}
+
 #define RETURN_LOOKUP_IF(CONDITION, KEY_STR)                                   \
   if (CONDITION)                                                               \
     return lookupVocab(KEY_STR);
@@ -132,7 +151,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
 
 #undef RETURN_LOOKUP_IF
 
-void SymbolicEmbedder::computeEmbeddings() {
+void SymbolicEmbedder::computeEmbeddings() const {
   if (F.isDeclaration())
     return;
   for (const auto &BB : F) {
@@ -142,7 +161,7 @@ void SymbolicEmbedder::computeEmbeddings() {
   }
 }
 
-Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
   for (const auto &I : BB) {
@@ -271,7 +290,6 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
     }
 
     std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-    Emb->computeEmbeddings();
 
     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
     OS << "Function vector: ";

@boomanaiden154 (Contributor) left a comment

It looks like the mutable objects are essentially caches, and are marked `mutable` because they are intended to be an internal implementation detail?

Contributor Author

Yes, that's the idea. The internal data structures that hold embeddings kind of act like caches.
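
The cache reading of `mutable` discussed here can be sketched in isolation. The classes below are hypothetical (`CachingEmbedder`, `ConstantEmbedder`, and the `NumComputes` counter do not exist in LLVM); they only mirror the shape of the patch: a `const` pure-virtual compute helper, `mutable` storage, and a getter that fills the storage when it is empty.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ir2vec::Embedding; not the LLVM type.
using Embedding = std::vector<double>;

class CachingEmbedder {
protected:
  // Caches: not part of the logical state, so const methods may fill them.
  mutable std::vector<Embedding> BBVecs;
  // Illustration only: counts how often the helper actually ran.
  mutable unsigned NumComputes = 0;

  // Const helper implemented by each kind of embedder.
  virtual void computeEmbeddings() const = 0;

public:
  virtual ~CachingEmbedder() = default;

  const std::vector<Embedding> &getBBVecs() const {
    if (BBVecs.empty())
      computeEmbeddings();
    return BBVecs;
  }

  unsigned getNumComputes() const { return NumComputes; }
};

// Toy subclass: every "basic block" gets the same 3-element vector.
class ConstantEmbedder : public CachingEmbedder {
  std::size_t NumBlocks;

  void computeEmbeddings() const override {
    ++NumComputes;
    BBVecs.assign(NumBlocks, Embedding(3, 1.0));
  }

public:
  explicit ConstantEmbedder(std::size_t N) : NumBlocks(N) {}
};

// Reads the cache twice through a const reference and reports how many
// times the compute helper ran.
unsigned computesAfterTwoReads(const CachingEmbedder &E) {
  E.getBBVecs();
  E.getBBVecs();
  return E.getNumComputes();
}
```

One property of this sketch worth noting: because `empty()` doubles as the "not yet computed" check, an input that produces no entries (here, `ConstantEmbedder(0)`) is recomputed on every call. A separate "computed" flag would avoid that, at the cost of one more piece of state.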

@svkeerthy svkeerthy force-pushed the users/svkeerthy/05-28-reducing_state branch from 5eab1c1 to 5cb422e Compare May 28, 2025 19:11
@svkeerthy svkeerthy force-pushed the users/svkeerthy/05-28-reducing_state branch from 5cb422e to 06b4d1b Compare May 28, 2025 19:18
@svkeerthy svkeerthy merged commit 3581e9b into main May 28, 2025
6 of 7 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/05-28-reducing_state branch May 28, 2025 19:19
google-yfyang pushed a commit to google-yfyang/llvm-project that referenced this pull request May 29, 2025
…41811)

Currently, users have to invoke two APIs: `computeEmbeddings()` followed
by getters to access the embeddings. This PR refactors the code to
reduce this *stateful* access of APIs. Users can now invoke the getters
directly; internally, the getters compute the embeddings.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…41811)
Labels
llvm:analysis (includes value tracking, cost tables and constant folding), mlgo