[IR2Vec] Simplifying creation of Embedder #143999

Merged

Conversation

svkeerthy
Contributor

@svkeerthy svkeerthy commented Jun 12, 2025

This change simplifies the API by removing the error handling complexity.

  • Changed Embedder::create() to return std::unique_ptr<Embedder> directly instead of Expected<std::unique_ptr<Embedder>>
  • Updated documentation and tests to reflect the new API
  • Added death test for invalid IR2Vec kind in debug mode
  • In release mode, the intent is to return nullptr for invalid kinds instead of creating an error; because that return follows llvm_unreachable, the result is not guaranteed

(Tracking issue - #141817)
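
As a rough illustration of the caller-side effect described above, the sketch below mirrors the new shape of the factory using stand-in types. `Kind`, `EmbedderSketch`, `SymbolicSketch`, `createSketch`, and `describe` are hypothetical names that are not part of LLVM; only the pattern (a factory that returns `std::unique_ptr` and a plain null check at the call site) reflects this change.

```cpp
#include <memory>
#include <string>

// Stand-in for IR2VecKind; hypothetical, not the LLVM enum.
enum class Kind { Symbolic };

// Stand-in for ir2vec::Embedder and its single concrete subclass.
struct EmbedderSketch {
  virtual ~EmbedderSketch() = default;
  virtual std::string name() const = 0;
};

struct SymbolicSketch : EmbedderSketch {
  std::string name() const override { return "symbolic"; }
};

// Factory returning the pointer directly; nullptr signals "no embedder".
std::unique_ptr<EmbedderSketch> createSketch(Kind K) {
  switch (K) {
  case Kind::Symbolic:
    return std::make_unique<SymbolicSketch>();
  }
  // Well-defined fallback for values that match no enumerator.
  return nullptr;
}

// Caller: one null check replaces the takeError()/handleAllErrors() sequence.
std::string describe(Kind K) {
  std::unique_ptr<EmbedderSketch> Emb = createSketch(K);
  if (!Emb)
    return "error creating embedder";
  return Emb->name();
}
```

The trade-off is that a null pointer carries no message, so callers can only report a generic failure, as the updated call sites in this change do.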

@svkeerthy svkeerthy changed the title from "Simplifying creation of Embedder" to "[IR2Vec] Simplifying creation of Embedder" Jun 12, 2025
Contributor Author

@albertcohen

@svkeerthy svkeerthy marked this pull request as ready for review June 12, 2025 23:57
@llvmbot llvmbot added the mlgo and llvm:analysis (includes value tracking, cost tables and constant folding) labels Jun 12, 2025
@llvmbot
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlgo

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Changes

This change simplifies the API by removing the error handling complexity.

  • Changed Embedder::create() to return std::unique_ptr<Embedder> directly instead of Expected<std::unique_ptr<Embedder>>
  • Updated documentation and tests to reflect the new API
  • Added death test for invalid IR2Vec kind in debug mode
  • In release mode, simply returns nullptr for invalid kinds instead of creating an error

(Tracking issue - #141817)


Full diff: https://github.com/llvm/llvm-project/pull/143999.diff

6 Files Affected:

  • (modified) llvm/docs/MLGO.rst (+1-6)
  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+2-2)
  • (modified) llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp (+3-7)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+8-11)
  • (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+3-4)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+17-27)
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 4f8fb3f59ca19..e7bba9995b75b 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
 
       // Assuming F is an llvm::Function&
       // For example, using IR2VecKind::Symbolic:
-      Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
+      std::unique_ptr<ir2vec::Embedder> Emb =
           ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
 
-      if (auto Err = EmbOrErr.takeError()) {
-        // Handle error in embedder creation
-        return;
-      }
-      std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
 
 3. **Compute and Access Embeddings**:
    Call ``getFunctionVector()`` to get the embedding for the function. 
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f1aaf4cd2e013..6efa6eac56af9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,8 +170,8 @@ class Embedder {
   virtual ~Embedder() = default;
 
   /// Factory method to create an Embedder object.
-  static Expected<std::unique_ptr<Embedder>>
-  create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
+  static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
+                                          const Vocab &Vocabulary);
 
   /// Returns a map containing instructions and the corresponding embeddings for
   /// the function F if it has been computed. If not, it computes the embeddings
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 29d3aaf46dc06..dd4eb7f0df053 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
     // We instantiate the IR2Vec embedder each time, as having an unique
     // pointer to the embedder as member of the class would make it
     // non-copyable. Instantiating the embedder in itself is not costly.
-    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+    auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
                                              *BB.getParent(), *IR2VecVocab);
-    if (Error Err = EmbOrErr.takeError()) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
-                                  EI.message());
-      });
+    if (!Embedder) {
+      BB.getContext().emitError("Error creating IR2Vec embeddings");
       return;
     }
-    auto Embedder = std::move(*EmbOrErr);
     const auto &BBEmbedding = Embedder->getBBVector(BB);
     // Subtract BBEmbedding from Function embedding if the direction is -1,
     // and add it if the direction is +1.
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51d3252d6606..68026618449d8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
       Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
       TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
 
-Expected<std::unique_ptr<Embedder>>
-Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
+std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
+                                           const Vocab &Vocabulary) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
     return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
   }
-  return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
+  llvm_unreachable("Unknown IR2Vec kind");
+  return nullptr;
 }
 
 // FIXME: Currently lookups are string based. Use numeric Keys
@@ -389,17 +390,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
 
   auto Vocab = IR2VecVocabResult.getVocabulary();
   for (Function &F : M) {
-    Expected<std::unique_ptr<Embedder>> EmbOrErr =
+    std::unique_ptr<Embedder> Emb =
         Embedder::create(IR2VecKind::Symbolic, F, Vocab);
-    if (auto Err = EmbOrErr.takeError()) {
-      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-        OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
-      });
+    if (!Emb) {
+      OS << "Error creating IR2Vec embeddings \n";
       continue;
     }
 
-    std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-
     OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
     OS << "Function vector: ";
     Emb->getFunctionVector().print(OS);
@@ -442,4 +439,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
   }
 
   return PreservedAnalyses::all();
-}
\ No newline at end of file
+}
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index e50486bcbcb27..ca4f5d0f63026 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
   }
 
   std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
-    auto EmbResult =
-        ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
-    EXPECT_TRUE(static_cast<bool>(EmbResult));
-    return std::move(*EmbResult);
+    auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    EXPECT_TRUE(static_cast<bool>(Emb));
+    return std::move(Emb);
   }
 };
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index c3ed6e90cd8fc..05af55b59323b 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  EXPECT_TRUE(static_cast<bool>(Result));
-
-  auto *Emb = Result->get();
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   EXPECT_NE(Emb, nullptr);
 }
 
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
 
-  // static_cast an invalid int to IR2VecKind
+// static_cast an invalid int to IR2VecKind
+#ifndef NDEBUG
+#if GTEST_HAS_DEATH_TEST
+  EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
+               "Unknown IR2Vec kind");
+#endif // GTEST_HAS_DEATH_TEST
+#else
   auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
   EXPECT_FALSE(static_cast<bool>(Result));
-
-  std::string ErrMsg;
-  llvm::handleAllErrors(
-      Result.takeError(),
-      [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
-  EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+#endif // NDEBUG
 }
 
 TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
   Instruction *AddInst = nullptr;
   Instruction *RetInst = nullptr;
 
-  float OriginalOpcWeight = ::OpcWeight;
-  float OriginalTypeWeight = ::TypeWeight;
-  float OriginalArgWeight = ::ArgWeight;
-
   void SetUp() override {
     V = {{"add", {1.0, 2.0}},
          {"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
 };
 
 TEST_F(IR2VecTestFixture, GetInstVecMap) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &InstMap = Emb->getInstVecMap();
 
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
 }
 
 TEST_F(IR2VecTestFixture, GetBBVecMap) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &BBMap = Emb->getBBVecMap();
 
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
 }
 
 TEST_F(IR2VecTestFixture, GetBBVector) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &BBVec = Emb->getBBVector(*BB);
 
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
 }
 
 TEST_F(IR2VecTestFixture, GetFunctionVector) {
-  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-  ASSERT_TRUE(static_cast<bool>(Result));
-  auto Emb = std::move(*Result);
+  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
 
   const auto &FuncVec = Emb->getFunctionVector();
 

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from ac378a9 to 2657262 Compare June 13, 2025 00:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from cc133a1 to 1a051f1 Compare June 13, 2025 00:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from 2657262 to 730ab91 Compare June 13, 2025 17:46
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch 2 times, most recently from 0d92141 to d71dd50 Compare June 13, 2025 18:18
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch 2 times, most recently from d31d756 to 32d16aa Compare June 17, 2025 18:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from d71dd50 to ea224df Compare June 17, 2025 18:01
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from 32d16aa to a426f2c Compare June 20, 2025 23:29
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from ea224df to 8b8932b Compare June 20, 2025 23:29
Member

@mtrofin mtrofin left a comment

nice!

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from a426f2c to 05453a3 Compare June 23, 2025 21:09
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from 8b8932b to 29ebe35 Compare June 23, 2025 21:10
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-_ir2vec_scale_vocab branch from 05453a3 to 2be999d Compare June 30, 2025 20:56
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from 29ebe35 to 037cc63 Compare June 30, 2025 20:56
Base automatically changed from users/svkeerthy/06-12-_ir2vec_scale_vocab to main June 30, 2025 21:09
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch 2 times, most recently from cd4066f to f09b163 Compare July 1, 2025 01:11
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-12-simplifying_creation_of_embedder branch from f09b163 to f33b9e3 Compare July 1, 2025 01:20
Contributor Author

svkeerthy commented Jul 1, 2025

Merge activity

  • Jul 1, 1:22 AM UTC: A user started a stack merge that includes this pull request via Graphite.
  • Jul 1, 1:24 AM UTC: @svkeerthy merged this pull request with Graphite.

@svkeerthy svkeerthy merged commit 9438048 into main Jul 1, 2025
6 of 8 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/06-12-simplifying_creation_of_embedder branch July 1, 2025 01:24
@llvm-ci
Collaborator

llvm-ci commented Jul 1, 2025

LLVM Buildbot has detected a new failure on builder llvm-clang-x86_64-win-fast running on as-builder-3 while building llvm at step 7 "test-build-unified-tree-check-llvm-unit".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/2/builds/27640

Here is the relevant piece of the build log for reference:
Step 7 (test-build-unified-tree-check-llvm-unit) failure: test (failure)
******************** TEST 'LLVM-Unit :: Analysis/./AnalysisTests.exe/24/86' FAILED ********************
Script(shard):
--
GTEST_OUTPUT=json:C:\buildbot\as-builder-3\llvm-clang-x86_64-win-fast\build\unittests\Analysis\.\AnalysisTests.exe-LLVM-Unit-8584-24-86.json GTEST_SHUFFLE=0 GTEST_TOTAL_SHARDS=86 GTEST_SHARD_INDEX=24 C:\buildbot\as-builder-3\llvm-clang-x86_64-win-fast\build\unittests\Analysis\.\AnalysisTests.exe
--

Script:
--
C:\buildbot\as-builder-3\llvm-clang-x86_64-win-fast\build\unittests\Analysis\.\AnalysisTests.exe --gtest_filter=IR2VecTest.CreateInvalidMode
--
C:\buildbot\as-builder-3\llvm-clang-x86_64-win-fast\llvm-project\llvm\unittests\Analysis\IR2VecTest.cpp(239): error: Value of: static_cast<bool>(Result)
  Actual: true
Expected: false


C:\buildbot\as-builder-3\llvm-clang-x86_64-win-fast\llvm-project\llvm\unittests\Analysis\IR2VecTest.cpp:239
Value of: static_cast<bool>(Result)
  Actual: true
Expected: false



********************
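
The failing assertion sits in the `#else` branch of `#ifndef NDEBUG` in the updated test, so this builder compiled the test with `NDEBUG` defined. A likely explanation, not confirmed in this thread, is that `llvm_unreachable` in such builds can expand to an optimizer hint, which makes the following `return nullptr;` unreachable and leaves the result for an invalid kind unspecified. The sketch below shows one alternative pattern that keeps the debug-time abort while giving a defined `nullptr` in release builds. `Kind`, `EmbedderSketch`, `SymbolicSketch`, `isKnownKind`, and `createOrNull` are hypothetical stand-ins, not LLVM APIs.

```cpp
#include <cassert>
#include <memory>

// Stand-in for IR2VecKind; hypothetical, not the LLVM enum.
enum class Kind { Symbolic };

struct EmbedderSketch {
  virtual ~EmbedderSketch() = default;
};
struct SymbolicSketch : EmbedderSketch {};

// True only for values that name a real enumerator.
bool isKnownKind(Kind K) {
  switch (K) {
  case Kind::Symbolic:
    return true;
  }
  return false;
}

std::unique_ptr<EmbedderSketch> createOrNull(Kind K) {
  if (!isKnownKind(K)) {
    // Aborts when assertions are enabled; expands to nothing under NDEBUG,
    // so control reaches the return below with well-defined behavior.
    assert(false && "Unknown kind");
    return nullptr;
  }
  return std::make_unique<SymbolicSketch>();
}
```

With this shape, a release-mode check such as `EXPECT_FALSE(static_cast<bool>(Result))` would have a defined outcome, at the cost of an explicit validity check before construction.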


rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
This change simplifies the API by removing the error handling complexity. 

- Changed `Embedder::create()` to return `std::unique_ptr<Embedder>` directly instead of `Expected<std::unique_ptr<Embedder>>`
- Updated documentation and tests to reflect the new API
- Added death test for invalid IR2Vec kind in debug mode
- In release mode, simply returns nullptr for invalid kinds instead of creating an error

(Tracking issue - llvm#141817)