-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[IR2Vec] Simplifying creation of Embedder #143999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR2Vec] Simplifying creation of Embedder #143999
Conversation
@llvm/pr-subscribers-mlgo @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesThis change simplifies the API by removing the error handling complexity.
(Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/143999.diff 6 Files Affected:
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 4f8fb3f59ca19..e7bba9995b75b 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
- Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
+ std::unique_ptr<ir2vec::Embedder> Emb =
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
- if (auto Err = EmbOrErr.takeError()) {
- // Handle error in embedder creation
- return;
- }
- std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
3. **Compute and Access Embeddings**:
Call ``getFunctionVector()`` to get the embedding for the function.
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index f1aaf4cd2e013..6efa6eac56af9 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,8 +170,8 @@ class Embedder {
virtual ~Embedder() = default;
/// Factory method to create an Embedder object.
- static Expected<std::unique_ptr<Embedder>>
- create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
+ static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
+ const Vocab &Vocabulary);
/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 29d3aaf46dc06..dd4eb7f0df053 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
// We instantiate the IR2Vec embedder each time, as having an unique
// pointer to the embedder as member of the class would make it
// non-copyable. Instantiating the embedder in itself is not costly.
- auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+ auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
*BB.getParent(), *IR2VecVocab);
- if (Error Err = EmbOrErr.takeError()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- BB.getContext().emitError("Error creating IR2Vec embeddings: " +
- EI.message());
- });
+ if (!Embedder) {
+ BB.getContext().emitError("Error creating IR2Vec embeddings");
return;
}
- auto Embedder = std::move(*EmbOrErr);
const auto &BBEmbedding = Embedder->getBBVector(BB);
// Subtract BBEmbedding from Function embedding if the direction is -1,
// and add it if the direction is +1.
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index f51d3252d6606..68026618449d8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
-Expected<std::unique_ptr<Embedder>>
-Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
+std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
+ const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
- return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
+ llvm_unreachable("Unknown IR2Vec kind");
+ return nullptr;
}
// FIXME: Currently lookups are string based. Use numeric Keys
@@ -389,17 +390,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
auto Vocab = IR2VecVocabResult.getVocabulary();
for (Function &F : M) {
- Expected<std::unique_ptr<Embedder>> EmbOrErr =
+ std::unique_ptr<Embedder> Emb =
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
- if (auto Err = EmbOrErr.takeError()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
- });
+ if (!Emb) {
+ OS << "Error creating IR2Vec embeddings \n";
continue;
}
- std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
-
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
Emb->getFunctionVector().print(OS);
@@ -442,4 +439,4 @@ PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
}
return PreservedAnalyses::all();
-}
\ No newline at end of file
+}
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index e50486bcbcb27..ca4f5d0f63026 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
}
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
- auto EmbResult =
- ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
- EXPECT_TRUE(static_cast<bool>(EmbResult));
- return std::move(*EmbResult);
+ auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ EXPECT_TRUE(static_cast<bool>(Emb));
+ return std::move(Emb);
}
};
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index c3ed6e90cd8fc..05af55b59323b 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- EXPECT_TRUE(static_cast<bool>(Result));
-
- auto *Emb = Result->get();
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_NE(Emb, nullptr);
}
@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- // static_cast an invalid int to IR2VecKind
+// static_cast an invalid int to IR2VecKind
+#ifndef NDEBUG
+#if GTEST_HAS_DEATH_TEST
+ EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
+ "Unknown IR2Vec kind");
+#endif // GTEST_HAS_DEATH_TEST
+#else
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
-
- std::string ErrMsg;
- llvm::handleAllErrors(
- Result.takeError(),
- [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
- EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+#endif // NDEBUG
}
TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;
- float OriginalOpcWeight = ::OpcWeight;
- float OriginalTypeWeight = ::TypeWeight;
- float OriginalArgWeight = ::ArgWeight;
-
void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
};
TEST_F(IR2VecTestFixture, GetInstVecMap) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
}
TEST_F(IR2VecTestFixture, GetBBVector) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector) {
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- ASSERT_TRUE(static_cast<bool>(Result));
- auto Emb = std::move(*Result);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
|
ac378a9
to
2657262
Compare
cc133a1
to
1a051f1
Compare
2657262
to
730ab91
Compare
0d92141
to
d71dd50
Compare
d31d756
to
32d16aa
Compare
d71dd50
to
ea224df
Compare
32d16aa
to
a426f2c
Compare
ea224df
to
8b8932b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
a426f2c
to
05453a3
Compare
8b8932b
to
29ebe35
Compare
05453a3
to
2be999d
Compare
29ebe35
to
037cc63
Compare
cd4066f
to
f09b163
Compare
f09b163
to
f33b9e3
Compare
Merge activity
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/2/builds/27640 Here is the relevant piece of the build log for the reference
|
This change simplifies the API by removing the error handling complexity. - Changed `Embedder::create()` to return `std::unique_ptr<Embedder>` directly instead of `Expected<std::unique_ptr<Embedder>>` - Updated documentation and tests to reflect the new API - Added death test for invalid IR2Vec kind in debug mode - In release mode, simply returns nullptr for invalid kinds instead of creating an error (Tracking issue - llvm#141817)
This change simplifies the API by removing the error handling complexity. - Changed `Embedder::create()` to return `std::unique_ptr<Embedder>` directly instead of `Expected<std::unique_ptr<Embedder>>` - Updated documentation and tests to reflect the new API - Added death test for invalid IR2Vec kind in debug mode - In release mode, simply returns nullptr for invalid kinds instead of creating an error (Tracking issue - llvm#141817)
This change simplifies the API by removing the error handling complexity.
Embedder::create()
to returnstd::unique_ptr<Embedder>
directly instead ofExpected<std::unique_ptr<Embedder>>
(Tracking issue - #141817)